# RPT - Teste de Esparsidade no BitNet 2B (Microsoft)
#
# Objetivo: validar se esparsidade melhora/mantem qualidade em modelo ternario.
# Base: RPT_BitNet_Microsoft.ipynb (que ja funciona).
#
# IMPORTANTE: Rode Cell 1, depois REINICIE o runtime, depois rode Cell 1 de novo e continue.

In [None]:
# CELL 1: SETUP (identico ao notebook Microsoft que funcionou)
# Na primeira vez, o pip install roda. Reinicie o runtime e rode esta cell de novo.
!pip install -q torch torchvision
!pip install -q git+https://github.com/huggingface/transformers.git accelerate

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers

# Report library versions so every run can be tied to a specific software stack.
for _label, _module in (('Transformers:', transformers), ('Torch:', torch)):
    print(_label, _module.__version__)

# Use the GPU when one is visible, otherwise fall back to the CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if _has_cuda else torch.device('cpu')
print('Device:', device)
if device.type == 'cuda':
    # Details of the first visible GPU.
    print('GPU:', torch.cuda.get_device_name(0))
    mem = torch.cuda.get_device_properties(0).total_memory
    vram_gb = mem / 1e9
    print('VRAM: {:.1f} GB'.format(vram_gb))
    print('BF16 suportado:', torch.cuda.is_bf16_supported())

In [None]:
# CELL 2: LOAD MODEL (same as the Microsoft notebook that worked)
MODEL_ID = 'microsoft/bitnet-b1.58-2B-4T-bf16'

print('Carregando tokenizer...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print('Carregando modelo BitNet 2B...')
# bfloat16 weights; device_map='auto' leaves device placement to transformers/accelerate.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, dtype=torch.bfloat16, device_map='auto',
)
model.eval()  # inference mode

# Total parameter count over every parameter tensor (embeddings and norms included).
n_params = 0
for _param in model.parameters():
    n_params += _param.numel()
print('Parametros: {:,.0f} ({:.1f}B)'.format(n_params, n_params / 1e9))
if device.type == 'cuda':
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    print('VRAM usada: {:.1f} GB'.format(allocated_gb))
print('Pronto!')

In [None]:
# CELL 3: ANALYZE WEIGHT DISTRIBUTION
# Before applying sparsity, look at what the weights actually contain.

print('=== DISTRIBUICAO DOS PESOS ===')
print()


def _matrix_stats(tensor):
    """Return (numel, exact zeros, count with |x| < 0.01, mean |x|) for one tensor."""
    values = tensor.float()
    magnitudes = values.abs()
    return (
        values.numel(),
        (values == 0).sum().item(),
        (magnitudes < 0.01).sum().item(),
        magnitudes.mean().item(),
    )


total_params = total_zeros = total_near_zero = 0
weight_layers = []  # one (name, numel, zeros, mean |w|) tuple per matrix

# Only tensors with dim >= 2 are counted; lower-rank ones (biases, norm scales) are skipped.
for name, param in model.named_parameters():
    if param.dim() < 2:
        continue
    n, zeros, near_zero, mean_abs = _matrix_stats(param.data)
    total_params += n
    total_zeros += zeros
    total_near_zero += near_zero
    weight_layers.append((name, n, zeros, mean_abs))

zero_pct = 100 * total_zeros / total_params
near_zero_pct = 100 * total_near_zero / total_params
print('Total parametros (matrices): {:,.0f}'.format(total_params))
print('Zeros exatos: {:,.0f} ({:.1f}%)'.format(total_zeros, zero_pct))
print('Perto de zero (<0.01): {:,.0f} ({:.1f}%)'.format(total_near_zero, near_zero_pct))
print()

# Per-layer summary of the first few matrices.
print('--- Exemplo: primeiras 5 camadas de peso ---')
for layer_name, numel, n_zero, avg_mag in weight_layers[:5]:
    print('  {} | params={:,} | zeros={:.1f}% | abs_mean={:.6f}'.format(
        layer_name[:60], numel, 100 * n_zero / numel, avg_mag))

# Distinct values of the first 2-D, non-embedding weight matrix.
print()
print('--- Valores unicos (primeira camada linear) ---')
first_linear = next(
    ((name, param) for name, param in model.named_parameters()
     if 'weight' in name and param.dim() == 2 and 'embed' not in name),
    None,
)
if first_linear is not None:
    name, param = first_linear
    w = param.data.float().flatten()
    uniques = torch.unique(w)
    label = name[:50]
    if len(uniques) <= 20:
        print('  {} tem {} valores unicos: {}'.format(
            label, len(uniques), uniques.tolist()))
    else:
        print('  {} tem {} valores unicos (min={:.4f}, max={:.4f})'.format(
            label, len(uniques), w.min().item(), w.max().item()))

print()
print('Esparsidade NATURAL do modelo: {:.1f}%'.format(zero_pct))

In [None]:
# CELL 4: TEST FUNCTIONS
# Same helpers as the Microsoft notebook, plus perplexity.

# Prompts for greedy text-completion checks; the progressive test (Cell 7)
# only uses the first 3 of them.
TEST_PROMPTS = [
    'The capital of France is',
    'Water boils at',
    'The largest planet in the solar system is',
    'Python is a programming language that',
    'In 1969, humans first',
]

# Short factual sentences scored by compute_perplexity. The first five are
# completions of TEST_PROMPTS.
# NOTE(review): 8 single sentences is a very small eval set; small PPL
# differences between sparsity levels are likely within noise.
EVAL_TEXTS = [
    'The capital of France is Paris, which is known for the Eiffel Tower.',
    'Water boils at 100 degrees Celsius at standard atmospheric pressure.',
    'The largest planet in the solar system is Jupiter, a gas giant.',
    'Python is a programming language that is widely used for data science.',
    'In 1969, humans first landed on the Moon during the Apollo 11 mission.',
    'Machine learning is a subset of artificial intelligence that focuses on learning from data.',
    'The speed of light in vacuum is approximately 299,792,458 meters per second.',
    'DNA contains the genetic instructions for the development of all living organisms.',
]

def compute_perplexity(model, tokenizer, texts, max_length=128):
    """Return the token-weighted perplexity of ``model`` over ``texts``.

    Each text is tokenized (truncated to ``max_length`` tokens) and scored with
    ``labels=input_ids``. The per-text mean loss is assumed to cover the
    ``seq_len - 1`` next-token predictions (HF causal-LM convention), so it is
    re-weighted by that count before averaging over all texts.

    Texts that tokenize to fewer than 2 tokens have nothing to predict; their
    mean loss would be NaN and poison the average, so they are skipped.

    Args:
        model: causal LM exposing ``.device``; calling it with the tokenized
            inputs and ``labels=`` must return an object with a scalar ``.loss``.
        tokenizer: callable returning a mapping of tensors including ``input_ids``.
        texts: iterable of strings to score.
        max_length: truncation length in tokens (default 128).

    Returns:
        float: ``exp(total loss / total predicted tokens)``.

    Raises:
        ValueError: if no text yields at least one predicted token.
    """
    total_loss = 0.0
    total_tokens = 0
    for text in texts:
        inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=max_length)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        n_tokens = inputs['input_ids'].shape[1] - 1
        if n_tokens < 1:
            continue  # single-token text: no next-token target
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs['input_ids'])
        total_loss += outputs.loss.item() * n_tokens
        total_tokens += n_tokens
    if total_tokens == 0:
        raise ValueError('compute_perplexity: no text produced a token to predict')
    avg_loss = total_loss / total_tokens
    return torch.exp(torch.tensor(avg_loss)).item()

def test_generation(model, tokenizer, prompts, max_new_tokens=30):
    """Greedily complete each prompt and return the decoded texts.

    Args:
        model: causal LM exposing ``.device`` and ``.generate``.
        tokenizer: used to encode the prompts and decode the outputs; its
            ``eos_token_id`` is passed as ``pad_token_id``.
        prompts: iterable of prompt strings.
        max_new_tokens: tokens to generate per prompt (default 30, as before).

    Returns:
        list[str]: one decoded string per prompt (the full ``out[0]`` sequence,
        i.e. prompt plus continuation), with special tokens removed.
    """
    results = []
    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy -> deterministic, comparable across sparsity levels
                pad_token_id=tokenizer.eos_token_id,
            )
        results.append(tokenizer.decode(out[0], skip_special_tokens=True))
    return results

def chat(user_msg, system_msg='You are a helpful AI assistant.', max_tokens=200):
    """Run one greedy chat turn against the notebook-global ``model``/``tokenizer``.

    Builds a [system, user] conversation, renders it with the tokenizer's chat
    template (assistant generation prompt appended), generates up to
    ``max_tokens`` new tokens and returns only the newly generated text.

    Args:
        user_msg: the user's message.
        system_msg: system prompt.
        max_tokens: maximum number of new tokens to generate.

    Returns:
        str: the decoded reply, with prompt tokens stripped and special tokens removed.
    """
    messages = [
        {'role': 'system', 'content': system_msg},
        {'role': 'user', 'content': user_msg},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Chat templates normally emit their own special tokens (e.g. BOS). Per the
    # HF chat-templating docs, text rendered with tokenize=False should be
    # tokenized with add_special_tokens=False so they are not added twice.
    # NOTE(review): confirm this model's template includes the BOS token.
    inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=False).to(model.device)
    input_len = inputs['input_ids'].shape[-1]
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    # out[0] starts with the prompt tokens; keep only the continuation.
    return tokenizer.decode(out[0][input_len:], skip_special_tokens=True)

print('Funcoes de teste definidas.')

In [None]:
# CELL 5: BASELINE (0% additional sparsity)

print('=== BASELINE: MODELO ORIGINAL (0% esparsidade adicional) ===')
print()

# Reference perplexity of the unmodified model.
ppl_baseline = compute_perplexity(model, tokenizer, EVAL_TEXTS)
print('Perplexity baseline: {:.2f}'.format(ppl_baseline))
print()

# Greedy completions for every test prompt.
print('--- Geracao ---')
baseline_texts = test_generation(model, tokenizer, TEST_PROMPTS)
for completion in baseline_texts:
    print('  {}'.format(completion))
print()

# A single chat turn.
print('--- Chat ---')
question = 'What is the capital of France?'
resp = chat(question)
print('  Q: {}'.format(question))
print('  A: {}'.format(resp))

In [None]:
# CELL 6: APPLY PROGRESSIVE SPARSITY
# v3: with detailed debug output to follow progress

import time

# Snapshot the prunable weights on the CPU so every sparsity level can be
# rebuilt from the original values. Tensors with dim < 2, embeddings and
# lm_head are excluded from pruning.
print('Salvando pesos originais na CPU...')
original_state = {
    name: param.data.cpu().clone()
    for name, param in model.named_parameters()
    if param.dim() >= 2 and 'embed' not in name and 'lm_head' not in name
}
print('Salvos {} tensores de peso.'.format(len(original_state)))

torch.cuda.empty_cache()
if device.type == 'cuda':
    print('VRAM apos salvar: {:.1f} GB'.format(torch.cuda.memory_allocated() / 1e9))


def _sample_nonzero_magnitudes(original_state, sample_size, generator):
    """Draw a size-proportional random sample of |w| over all non-zero weights.

    One global rate ``min(1, sample_size / total_nonzero)`` is applied to every
    tensor (at least one value per non-empty tensor), so large matrices are
    represented in proportion to their size and the sample's quantiles
    estimate the quantiles of the pooled weights.

    Returns:
        (sample, total_nonzero): 1-D float32 tensor of magnitudes (empty when
        there are no non-zero weights) and the total count of non-zero weights.
    """
    counts = {name: int((w != 0).sum().item()) for name, w in original_state.items()}
    total_nonzero = sum(counts.values())
    if total_nonzero == 0:
        return torch.empty(0), 0
    rate = min(1.0, sample_size / total_nonzero)
    pieces = []
    for name, w in original_state.items():
        n = counts[name]
        if n == 0:
            continue
        mags = w.float().abs()
        mags = mags[mags != 0]
        k = max(1, int(round(n * rate)))
        if k < n:
            mags = mags[torch.randperm(n, generator=generator)[:k]]
        pieces.append(mags)
    return torch.cat(pieces), total_nonzero


def apply_sparsity(model, sparsity_pct, original_state, verbose=False,
                   sample_size=10_000_000, seed=0):
    """Apply global magnitude pruning to ``model``, starting from ``original_state``.

    A single global threshold is estimated as the ``sparsity_pct`` quantile of
    |w| over the non-zero original weights, using a seeded, size-proportional
    random sample. Every parameter whose name is in ``original_state`` is then
    reset to its original value with all entries of magnitude ``<= threshold``
    set to zero. Parameters not in ``original_state`` are left untouched.

    Args:
        model: module whose ``named_parameters()`` names match ``original_state`` keys.
        sparsity_pct: target percentage (0-100) of non-zero weights to prune;
            ``<= 0`` only restores the original weights.
        original_state: mapping ``name -> tensor`` holding the original weights.
        verbose: print progress and timings.
        sample_size: approximate number of magnitudes sampled for the threshold.
        seed: seed of the sampling RNG, so repeated calls give the same threshold.

    Returns:
        float: percentage of zeros over the pruned tensors after the update
        (weights that were already zero included); 0.0 when restoring.
    """
    if sparsity_pct <= 0:
        if verbose:
            print('  Restaurando pesos originais...')
        for name, param in model.named_parameters():
            if name in original_state:
                param.data.copy_(original_state[name].to(param.device))
        if verbose:
            print('  Restaurado.')
        return 0.0

    # STEP 1: sample weight magnitudes on the CPU
    if verbose:
        print('  [1/3] Amostrando magnitudes de {} camadas...'.format(len(original_state)))
    t0 = time.time()
    generator = torch.Generator().manual_seed(seed)
    sample, total_nonzero = _sample_nonzero_magnitudes(original_state, sample_size, generator)
    if verbose:
        print('  [1/3] Amostrados {:,} valores de {:,} nao-zeros ({:.1f}s)'.format(
            len(sample), total_nonzero, time.time() - t0))

    # STEP 2: threshold = sparsity_pct quantile of the sample
    if verbose:
        print('  [2/3] Calculando threshold para {}%...'.format(sparsity_pct))
    t1 = time.time()
    if sample.numel() == 0:
        threshold = 0.0  # no non-zero weights: pruning cannot add zeros
    else:
        idx = min(int(sample.numel() * sparsity_pct / 100.0), sample.numel() - 1)
        # (idx + 1)-th smallest value == sorted(sample)[idx], without a full sort
        threshold = torch.kthvalue(sample, idx + 1).values.item()
    del sample
    if verbose:
        print('  [2/3] Threshold: {:.6f} ({:.1f}s)'.format(threshold, time.time() - t1))

    # STEP 3: rebuild each tensor from the original and zero small magnitudes
    if verbose:
        print('  [3/3] Aplicando mascara em {} camadas...'.format(len(original_state)))
    t2 = time.time()
    total_pruned = 0
    total_weights = 0
    count = 0
    for name, param in model.named_parameters():
        if name not in original_state:
            continue
        w = original_state[name].clone()
        keep = w.float().abs() > threshold
        w[~keep] = 0
        param.data.copy_(w.to(param.device))
        total_pruned += (~keep).sum().item()
        total_weights += w.numel()
        count += 1
        if verbose and count % 50 == 0:
            print('    ... {}/{} camadas'.format(count, len(original_state)))

    actual = 100.0 * total_pruned / total_weights if total_weights else 0.0
    if verbose:
        print('  [3/3] Feito: {:,} de {:,} zerados ({:.1f}%) em {:.1f}s'.format(
            total_pruned, total_weights, actual, time.time() - t2))
    return actual


print('Funcao de esparsidade pronta.')
print()
# Smoke test: prune 30% with progress output, then restore the original weights.
print('Teste rapido: aplicar 30% com debug...')
s = apply_sparsity(model, sparsity_pct=30, original_state=original_state, verbose=True)
print()
print('Restaurando...')
apply_sparsity(model, sparsity_pct=0, original_state=original_state, verbose=True)
print('OK!')

In [None]:
# CELL 7: PROGRESSIVE SPARSITY TEST
# Levels: 0%, 10%, 20%, 30%, 50%, 70%, 80%, 90%

sparsity_levels = [0, 10, 20, 30, 50, 70, 80, 90]
results = []


def _is_coherent(texts, window=6, min_distinct=3):
    """Heuristic degeneration check for generated texts.

    A text is flagged as degenerate when it has more than 3 words and its last
    ``window`` words contain fewer than ``min_distinct`` distinct words (stuck
    repeating one or two words). Returns False if any text is degenerate,
    True otherwise.
    """
    for text in texts:
        words = text.split()
        if len(words) > 3 and len(set(words[-window:])) < min_distinct:
            return False
    return True


print('=== TESTE PROGRESSIVO DE ESPARSIDADE ===')
print('Testando {} niveis...'.format(len(sparsity_levels)))
print()

try:
    for target_sp in sparsity_levels:
        print('--- Esparsidade: {}% ---'.format(target_sp))
        t0 = time.time()

        # apply_sparsity rebuilds the weights from original_state at every level.
        actual_sp = apply_sparsity(model, target_sp, original_state)

        # Perplexity on the evaluation texts.
        ppl = compute_perplexity(model, tokenizer, EVAL_TEXTS)

        # Generation on the first 3 prompts only, to keep each level fast.
        texts = test_generation(model, tokenizer, TEST_PROMPTS[:3])

        dt = time.time() - t0
        coherent = _is_coherent(texts)

        results.append({
            'target_sparsity': target_sp,
            'actual_sparsity': actual_sp,
            'perplexity': ppl,
            'coherent': coherent,
            'sample_text': texts[0],
            'time_sec': dt,
        })

        status = 'OK' if coherent else 'DEGRADOU'
        print('  Esparsidade real: {:.1f}% | PPL: {:.2f} | {} | {:.1f}s'.format(
            actual_sp, ppl, status, dt))
        print('  Exemplo: {}'.format(texts[0][:100]))
        print()
finally:
    # Restore the dense model even if a level raised (OOM, interrupt, ...),
    # so later cells never run on a silently pruned model.
    apply_sparsity(model, 0, original_state)
    print('Modelo restaurado ao original.')

In [None]:
# CELL 8: RESULTS TABLE

print('=' * 70)
print('RESULTADOS: ESPARSIDADE NO BITNET 2B')
print('=' * 70)
print()
print('{:<12} {:<12} {:<10} {:<10} {}'.format(
    'Esparsidade', 'Real', 'PPL', 'Status', 'Exemplo'))
print('-' * 70)

for r in results:
    status = 'OK' if r['coherent'] else 'QUEBROU'
    sample = r['sample_text'][:40].replace('\n', ' ')
    print('{:<12} {:<12} {:<10} {:<10} {}'.format(
        '{}%'.format(r['target_sparsity']),
        '{:.1f}%'.format(r['actual_sparsity']),
        '{:.2f}'.format(r['perplexity']),
        status,
        sample))

print()
print('--- ANALISE ---')
# results[0] is the 0% run (Cell 7 tests level 0 first).
baseline_ppl = results[0]['perplexity']
print('PPL baseline (0%): {:.2f}'.format(baseline_ppl))

# Sparsity limit: the sparsest level whose text is still coherent and whose
# PPL stays below 1.5x the baseline. Selected by max target_sparsity, so the
# choice does not depend on the order of `results`.
acceptable = [r for r in results
              if r['coherent'] and r['perplexity'] < baseline_ppl * 1.5]
best_sparse = max(acceptable, key=lambda r: r['target_sparsity']) if acceptable else None

if best_sparse and best_sparse['target_sparsity'] > 0:
    print('Melhor esparsidade com qualidade: {}% (PPL {:.2f}, {:.1f}% do baseline)'.format(
        best_sparse['target_sparsity'],
        best_sparse['perplexity'],
        100 * best_sparse['perplexity'] / baseline_ppl))
    print()
    print('CONCLUSAO: BitNet 2B tolera {}% de esparsidade adicional!'.format(
        best_sparse['target_sparsity']))
else:
    print('CONCLUSAO: Esparsidade degrada qualidade rapidamente neste modelo.')

# Check whether any sparsity level actually LOWERS perplexity below the baseline
# (as observed in RPT_BitNet_Projeto).
improved = [r for r in results if r['perplexity'] < baseline_ppl and r['target_sparsity'] > 0]
if improved:
    best = min(improved, key=lambda x: x['perplexity'])
    print()
    print('*** ESPARSIDADE MELHOROU! ***')
    print('Melhor: {}% esparsidade -> PPL {:.2f} (era {:.2f}, melhoria de {:.1f}%)'.format(
        best['target_sparsity'],
        best['perplexity'],
        baseline_ppl,
        100 * (1 - best['perplexity'] / baseline_ppl)))

In [None]:
# CELL 9: DETAILED TEST OF THE BEST LEVEL
# Applies the best sparsity level and runs the full set of tests.

# Choose the sparsest level (> 0%) that stayed coherent with PPL < 1.5x baseline.
baseline_ppl = results[0]['perplexity']
coherent_results = [r for r in results if r['coherent']]
good = [r for r in coherent_results
        if r['perplexity'] < baseline_ppl * 1.5 and r['target_sparsity'] > 0]

if good:
    best_level = max(good, key=lambda x: x['target_sparsity'])['target_sparsity']
else:
    best_level = 10  # conservative fallback when no level met the criteria

print('=== TESTE DETALHADO: {}% ESPARSIDADE ==='.format(best_level))
print()

apply_sparsity(model, best_level, original_state)
try:
    # Full sentence-completion test (same greedy settings as test_generation).
    print('--- Completar frase ---')
    for prompt in TEST_PROMPTS:
        # One prompt per call so each completion is printed as soon as it is ready.
        (text,) = test_generation(model, tokenizer, [prompt])
        print('  {}'.format(text))
    print()

    # Chat test
    print('--- Chat ---')
    perguntas_chat = [
        'What is the capital of France?',
        'Explain what BitNet is in 2 sentences.',
        'Write a Python function that checks if a number is prime.',
    ]
    for p in perguntas_chat:
        print('User:', p)
        print('BitNet ({}% sparse):'.format(best_level), chat(p))
        print('-' * 60)
finally:
    # Restore the dense weights even if generation fails midway.
    apply_sparsity(model, 0, original_state)
    print()
    print('Modelo restaurado.')

In [None]:
# CELL 10: SAVE RESULTS
import datetime
import json

# Fields copied from each Cell 7 result into the report.
RESULT_FIELDS = ('target_sparsity', 'actual_sparsity', 'perplexity',
                 'coherent', 'sample_text', 'time_sec')

report = {
    'model': MODEL_ID,
    'date': datetime.date.today().isoformat(),  # local date the report is written
    'experiment': 'sparsity_on_bitnet_2b',
    'baseline_ppl': results[0]['perplexity'],
    'results': [{field: r[field] for field in RESULT_FIELDS} for r in results],
}

# NOTE(review): json.dump writes non-finite floats (e.g. a NaN/inf perplexity
# at high sparsity) as the non-standard tokens NaN/Infinity, which strict
# JSON parsers reject.
with open('sparsity_bitnet_results.json', 'w', encoding='utf-8') as f:
    json.dump(report, f, indent=2)

print('Resultados salvos em sparsity_bitnet_results.json')
print()
print('=== RESUMO FINAL ===')
print('Modelo: {}'.format(MODEL_ID))
print('PPL baseline: {:.2f}'.format(report['baseline_ppl']))
print()
for r in results:
    delta = r['perplexity'] - report['baseline_ppl']
    sign = '+' if delta > 0 else ''  # negative deltas already carry their '-'
    status = 'OK' if r['coherent'] else 'QUEBROU'
    print('  {:>3}% esparso -> PPL {:.2f} ({}{:.2f}) [{}]'.format(
        r['target_sparsity'], r['perplexity'], sign, delta, status))