# RPT - Pruning Progressivo com Fine-Tune no BitNet 2B
#
# Objetivo: empurrar esparsidade alem de 10% usando ciclos de poda+treino.
# Base: RPT_BitNet_Microsoft.ipynb + RPT_BitNet_Sparsity_Test.ipynb
#
# Resultado anterior: 10% magnitude pruning cru -> PPL 6.94 (-26%)
# Meta: chegar a 30-50%+ mantendo qualidade.
#
# IMPORTANTE: Rode Cell 1, REINICIE o runtime, rode Cell 1 de novo e continue.

In [None]:
# CELL 1: SETUP
!pip install -q torch torchvision
!pip install -q git+https://github.com/huggingface/transformers.git accelerate datasets

import torch
import time
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers

# Report library versions so a run can be tied to a specific environment.
print('Transformers:', transformers.__version__)
print('Torch:', torch.__version__)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
# `device` is read by later cells to decide whether to print VRAM usage.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
if device.type == 'cuda':
    print('GPU:', torch.cuda.get_device_name(0))
    # total_memory is divided by 1e9 below, i.e. printed in decimal gigabytes.
    mem = torch.cuda.get_device_properties(0).total_memory
    print('VRAM: {:.1f} GB'.format(mem / 1e9))

In [None]:
# CELL 2: LOAD MODEL (H100-optimized)
# Hugging Face repository id of the checkpoint loaded below; also written to
# the JSON report in a later cell.
MODEL_ID = 'microsoft/bitnet-b1.58-2B-4T-bf16'

# H100 optimizations: allow TF32 for CUDA matmuls and cuDNN, and request the
# 'high' float32 matmul precision setting.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

print('Carregando tokenizer...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Fall back to the EOS token when the tokenizer defines no padding token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print('Carregando modelo BitNet 2B...')
# Weights are loaded in bfloat16; device placement is delegated to
# device_map='auto'.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,
    device_map='auto'
)

# torch.compile to optimize kernels on the H100.
# NOTE(review): from here on `model` is the object returned by torch.compile,
# not the original module; later cells guard `.device` access with hasattr().
print('Compilando modelo (torch.compile)...')
model = torch.compile(model)

# Total number of elements over all parameters of the (compiled) model.
n_params = sum(p.numel() for p in model.parameters())
print('Parametros: {:,.0f} ({:.1f}B)'.format(n_params, n_params / 1e9))
if device.type == 'cuda':
    print('VRAM usada: {:.1f} GB'.format(torch.cuda.memory_allocated() / 1e9))
print('TF32 ativado: {}'.format(torch.backends.cuda.matmul.allow_tf32))
print('Pronto!')

In [None]:
# CELL 3: LOAD DATASET
from datasets import load_dataset

# Tokens per chunk (short sequences keep each training step fast).
SEQ_LEN = 128


def _tokenize_split(split_dataset):
    """Return the concatenated token ids of every usable line in *split_dataset*.

    Lines whose stripped text is shorter than 20 characters are skipped, and
    no special tokens are inserted between lines. Uses the notebook-level
    ``tokenizer``.
    """
    ids = []
    for example in split_dataset:
        text = example['text'].strip()
        if len(text) < 20:
            continue
        ids.extend(tokenizer.encode(text, add_special_tokens=False))
    return ids


def _split_into_chunks(ids, seq_len):
    """Cut *ids* into consecutive, non-overlapping LongTensors of *seq_len* tokens.

    A trailing remainder shorter than *seq_len* is dropped. The ``+ 1`` in the
    range bound keeps the final full chunk when ``len(ids)`` is an exact
    multiple of *seq_len* (a bound of ``len(ids) - seq_len`` would lose it).
    """
    return [
        torch.tensor(ids[start:start + seq_len], dtype=torch.long)
        for start in range(0, len(ids) - seq_len + 1, seq_len)
    ]


print('Carregando WikiText-2...')
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')

print('Tokenizando (seq_len={})...'.format(SEQ_LEN))
all_ids = _tokenize_split(dataset)
chunks = _split_into_chunks(all_ids, SEQ_LEN)

print('Tokens totais: {:,}'.format(len(all_ids)))
print('Chunks de {}: {:,}'.format(SEQ_LEN, len(chunks)))

# The validation split goes through exactly the same pipeline as training.
val_dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='validation')
val_ids = _tokenize_split(val_dataset)
val_chunks = _split_into_chunks(val_ids, SEQ_LEN)

print('Val chunks: {:,}'.format(len(val_chunks)))
print('Dataset pronto.')

In [None]:
# CELL 4: FUNCTIONS (H100 - maximum speed)

import random

# Short prompts passed to test_gen() by later cells as a quick qualitative
# check of the text the model generates.
TEST_PROMPTS = [
    'The capital of France is',
    'Water boils at',
    'The largest planet in the solar system is',
]


def compute_ppl(model, val_chunks, max_batches=50):
    """Return the perplexity of *model* on the first *max_batches* chunks.

    Each chunk is scored on its own (batch size 1). The loss returned by the
    model is weighted by the number of predicted tokens in the chunk
    (``len(chunk) - 1``), and the perplexity is ``exp`` of the weighted mean.
    This assumes ``out.loss`` is the mean over the shifted targets, as for
    Hugging Face causal LMs -- TODO confirm for other model types.

    Args:
        model: Callable as ``model(input_ids=..., labels=...)`` returning an
            object with a scalar ``loss`` tensor. Switched to eval mode.
        val_chunks: Sequence of 1-D LongTensors of token ids.
        max_batches: Number of leading chunks to score (``None`` scores all).

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: If there is no token to score (no chunks, or every chunk
            has fewer than two tokens).
    """
    model.eval()
    # Resolve the target device once instead of once per chunk.
    target_device = model.device if hasattr(model, 'device') else 'cuda'

    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for chunk in val_chunks[:max_batches]:
            n_predicted = chunk.shape[0] - 1
            if n_predicted < 1:
                # Nothing to predict in a 0/1-token chunk; scoring it would
                # only inject a NaN loss into the running sum.
                continue
            input_ids = chunk.unsqueeze(0).to(target_device)
            out = model(input_ids=input_ids, labels=input_ids)
            total_loss += out.loss.item() * n_predicted
            total_tokens += n_predicted

    if total_tokens == 0:
        raise ValueError('compute_ppl: no tokens to evaluate (val_chunks is empty)')
    return torch.exp(torch.tensor(total_loss / total_tokens)).item()


def test_gen(model, tokenizer, prompts):
    """Greedy-decode a short continuation for each prompt.

    The model is put in eval mode and every prompt is generated independently
    with at most 25 new tokens and sampling disabled. Returns the decoded
    texts (special tokens removed) as a list, in prompt order.
    """
    model.eval()
    target_device = model.device if hasattr(model, 'device') else 'cuda'
    completions = []
    with torch.no_grad():
        for prompt in prompts:
            encoded = tokenizer(prompt, return_tensors='pt').to(target_device)
            generated = model.generate(
                **encoded,
                max_new_tokens=25,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )
            first_sequence = generated[0]
            completions.append(
                tokenizer.decode(first_sequence, skip_special_tokens=True))
    return completions


def _iter_prunable(model, protect_layers):
    """Yield the parameters eligible for pruning: >= 2-D and not protected by name."""
    for name, param in model.named_parameters():
        if param.dim() < 2:
            continue
        if any(tag in name for tag in protect_layers):
            continue
        yield param


def prune_magnitude(model, sparsity_pct, protect_layers=None, max_samples=10_000_000):
    """Global magnitude pruning towards a cumulative sparsity target.

    Zeroes, in place, the smallest-magnitude weights of every prunable tensor
    so that about ``sparsity_pct`` percent of ALL prunable weights are zero
    afterwards. Weights that are already zero count towards the target, so
    calling this repeatedly with increasing targets (5, 10, 15, ...) yields
    those cumulative sparsities rather than compounding them.

    The global threshold is estimated from a sample of non-zero magnitudes in
    which every tensor contributes in proportion to its size; when the total
    number of non-zero weights fits in ``max_samples`` the threshold is exact.

    Args:
        model: Module exposing ``named_parameters()``.
        sparsity_pct: Target percentage (0-100) of zero weights.
        protect_layers: Substrings of parameter names that must not be
            pruned. Defaults to ``['embed', 'lm_head']``. Parameters with
            fewer than two dimensions are always skipped.
        max_samples: Upper bound on the magnitudes used for the estimate.

    Returns:
        The sparsity (percent) actually present in the prunable tensors after
        pruning; ``0.0`` when there is nothing to prune.
    """
    if protect_layers is None:
        protect_layers = ['embed', 'lm_head']

    params = list(_iter_prunable(model, protect_layers))
    total_weights = sum(p.numel() for p in params)
    nonzero_counts = [int(torch.count_nonzero(p.data).item()) for p in params]
    total_nonzero = sum(nonzero_counts)
    already_zero = total_weights - total_nonzero

    # Number of currently non-zero weights that still have to be removed to
    # reach the cumulative target.
    to_prune = total_weights * sparsity_pct / 100.0 - already_zero

    threshold = None
    if total_nonzero > 0 and to_prune > 0:
        keep_rate = min(1.0, max_samples / total_nonzero)
        samples = []
        for param, n_nonzero in zip(params, nonzero_counts):
            if n_nonzero == 0:
                continue
            flat = param.data.reshape(-1)
            # Work on the tensor's own device; only the sample goes to CPU.
            magnitudes = flat[flat != 0].abs().float()
            if keep_rate < 1.0:
                # Same sampling rate for every tensor, so large matrices are
                # not under-represented in the global estimate.
                n_pick = max(1, int(round(n_nonzero * keep_rate)))
                pick = torch.randint(n_nonzero, (n_pick,), device=magnitudes.device)
                magnitudes = magnitudes[pick]
            samples.append(magnitudes.cpu())
        sample = torch.cat(samples)
        k = min(sample.numel(),
                int(round(sample.numel() * to_prune / total_nonzero)))
        if k >= 1:
            # k-th smallest magnitude: everything <= it gets pruned.
            threshold = sample.kthvalue(k).values.item()
        del sample, samples

    total_pruned = already_zero
    if threshold is not None:
        total_pruned = 0
        for param in params:
            drop = param.data.abs().float() <= threshold
            param.data.masked_fill_(drop, 0)
            total_pruned += int(drop.sum().item())

    actual = 100.0 * total_pruned / total_weights if total_weights > 0 else 0.0
    print('    [prune] {:.1f}% zerados ({:,}/{:,})'.format(actual, total_pruned, total_weights))
    return actual


def finetune(model, chunks, n_steps=300, lr=5e-4, batch_size=64, preserve_zeros=True):
    """Fine-tune *model* with AdamW on random batches drawn from *chunks*.

    Uses a cosine learning-rate schedule over ``n_steps`` and clips the
    gradient norm to 1.0. The model is left in eval mode on return.

    Args:
        model: Callable as ``model(input_ids=..., labels=...)`` returning an
            object with a scalar ``loss`` tensor.
        chunks: Sequence of equal-length 1-D LongTensors; each step stacks
            ``batch_size`` of them, sampled with replacement.
        n_steps: Number of optimizer steps.
        lr: Initial learning rate.
        batch_size: Chunks per step.
        preserve_zeros: When true (default), entries of >= 2-D parameters
            that are exactly zero on entry are kept at zero for the whole
            run, so a pruned model stays pruned instead of regrowing the
            removed weights. This holds one boolean mask per affected tensor
            in memory for the duration of the call.

    Returns:
        Mean loss over the last (up to) 50 steps; ``nan`` when no step ran.
    """
    model.train()
    target_device = model.device if hasattr(model, 'device') else 'cuda'

    # Remember which weights are pruned so training cannot bring them back.
    frozen = []
    if preserve_zeros:
        for param in model.parameters():
            if param.dim() < 2:
                continue
            zero_mask = param.data == 0
            if zero_mask.any():
                frozen.append((param, zero_mask))

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, n_steps))

    losses = []
    t0 = time.time()
    for step in range(n_steps):
        batch_chunks = random.choices(chunks, k=batch_size)
        input_ids = torch.stack(batch_chunks).to(target_device)

        out = model(input_ids=input_ids, labels=input_ids)
        loss = out.loss
        loss.backward()

        # Drop gradients of pruned weights before clipping so they neither
        # update those weights nor inflate the clipped norm.
        for param, zero_mask in frozen:
            if param.grad is not None:
                param.grad.masked_fill_(zero_mask, 0)

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        # Safety net: re-apply the mask in case the optimizer moved anything.
        for param, zero_mask in frozen:
            param.data.masked_fill_(zero_mask, 0)

        losses.append(loss.item())

        if (step + 1) % 50 == 0:
            avg = sum(losses[-50:]) / len(losses[-50:])
            elapsed = time.time() - t0
            steps_per_sec = (step + 1) / elapsed
            eta = (n_steps - step - 1) / steps_per_sec
            print('    [ft] {}/{} | loss: {:.3f} | {:.1f} steps/s | ETA: {:.0f}s'.format(
                step + 1, n_steps, avg, steps_per_sec, eta))

    model.eval()
    # Release optimizer state and masks before the next evaluation/pruning.
    optimizer.zero_grad(set_to_none=True)
    del optimizer, scheduler, frozen
    torch.cuda.empty_cache()

    recent = losses[-50:]
    final_loss = sum(recent) / len(recent) if recent else float('nan')
    total_time = time.time() - t0
    print('    [ft] Pronto em {:.0f}s | Loss final: {:.3f}'.format(total_time, final_loss))
    return final_loss


# Informational banner only: the batch/seq figures below are hard-coded text
# and do not track the arguments later passed to finetune() or SEQ_LEN.
print('Funcoes definidas (H100 max speed).')
print('batch=64 x seq=128 = 8K tokens/step')

In [None]:
# CELL 5: BASELINE

print('=== BASELINE ===')
# Perplexity of the unpruned model; every later level is compared to it.
ppl_baseline = compute_ppl(model, val_chunks)
print('PPL baseline: {:.2f}'.format(ppl_baseline))

# Reference generations from the unpruned model for the fixed test prompts.
texts_baseline = test_gen(model, tokenizer, TEST_PROMPTS)
for t in texts_baseline:
    print('  {}'.format(t))

if device.type == 'cuda':
    print('VRAM: {:.1f} GB'.format(torch.cuda.memory_allocated() / 1e9))

In [None]:
# CELL 6: PROGRESSIVE PRUNING WITH FINE-TUNING (H100 - MAX SPEED)
# batch=64 makes use of the H100's VRAM; 300 steps are enough (more data per step)

SPARSITY_STEPS = [5, 10, 15, 20, 25, 30, 40, 50]
FT_STEPS = 300       # fewer steps (each step sees 64 chunks)
FT_LR = 5e-4
BATCH_SIZE = 64      # makes use of the H100's VRAM

# The first record is the unpruned baseline, so later cells can treat every
# level uniformly.
results = [{
    'sparsity': 0,
    'actual_sparsity': 0.0,
    'ppl_before_ft': ppl_baseline,
    'ppl_after_ft': ppl_baseline,
    'ft_loss': 0,
    'sample': texts_baseline[0],
    'time_sec': 0,
}]

tokens_per_level = FT_STEPS * BATCH_SIZE * SEQ_LEN
print('=== PRUNING PROGRESSIVO (H100 MAX SPEED) ===')
print('Fine-tune: {} steps x batch {} x seq {} = {:.1f}M tokens/nivel'.format(
    FT_STEPS, BATCH_SIZE, SEQ_LEN, tokens_per_level / 1e6))
print('AdamW lr={}'.format(FT_LR))
print()

# Each level runs: prune -> measure -> fine-tune -> measure again, always on
# the same model object, so the levels build on one another.
for target_sp in SPARSITY_STEPS:
    print('=' * 50)
    print('NIVEL: {}%'.format(target_sp))
    level_start = time.time()

    print('  [1/4] Podando...')
    actual_sp = prune_magnitude(model, target_sp)

    print('  [2/4] PPL antes...')
    ppl_before = compute_ppl(model, val_chunks)
    print('    {:.2f}'.format(ppl_before))

    print('  [3/4] Fine-tuning...')
    ft_loss = finetune(
        model, chunks,
        n_steps=FT_STEPS, lr=FT_LR, batch_size=BATCH_SIZE)

    print('  [4/4] PPL depois + texto...')
    ppl_after = compute_ppl(model, val_chunks)
    texts = test_gen(model, tokenizer, TEST_PROMPTS)

    level_seconds = time.time() - level_start

    results.append({
        'sparsity': target_sp,
        'actual_sparsity': actual_sp,
        'ppl_before_ft': ppl_before,
        'ppl_after_ft': ppl_after,
        'ft_loss': ft_loss,
        'sample': texts[0],
        'time_sec': level_seconds,
    })

    # Perplexity regained by fine-tuning, and the gap to the unpruned model.
    ppl_regained = ppl_before - ppl_after
    pct_vs_baseline = 100 * (ppl_after / ppl_baseline - 1)
    print()
    print('  {}%: PPL {:.2f} -> {:.2f} (recuperou {:.2f}) | vs base: {:+.1f}% | {:.0f}s'.format(
        target_sp, ppl_before, ppl_after, ppl_regained,
        pct_vs_baseline, level_seconds))
    print('  Texto: {}'.format(texts[0][:80]))
    print()

    # Stop once the model has degraded to more than 3x the baseline perplexity.
    if ppl_after > ppl_baseline * 3:
        print('  *** PPL {:.1f}x baseline, parando. ***'.format(ppl_after / ppl_baseline))
        break

print('Finalizado.')

In [None]:
# CELL 7: RESULTS TABLE

table_rule = '=' * 80
row_format = '{:<10} {:<10} {:<12} {:<12} {:<10} {}'

print(table_rule)
print('RESULTADOS: PRUNING PROGRESSIVO COM FINE-TUNE NO BITNET 2B')
print(table_rule)
print()
print(row_format.format(
    'Sparse%', 'Real%', 'PPL antes', 'PPL depois', 'vs Base', 'Texto'))
print('-' * 80)

for r in results:
    # The baseline row has nothing to be compared against.
    delta_str = '-' if r['sparsity'] == 0 else '{:+.1f}%'.format(
        100 * (r['ppl_after_ft'] / ppl_baseline - 1))
    # Keep the sample on a single table line.
    sample = r['sample'][:35].replace('\n', ' ')
    cells = (
        '{}%'.format(r['sparsity']),
        '{:.1f}%'.format(r['actual_sparsity']),
        '{:.2f}'.format(r['ppl_before_ft']),
        '{:.2f}'.format(r['ppl_after_ft']),
        delta_str,
        sample,
    )
    print(row_format.format(*cells))

print()
print('--- COMPARACAO ---')
print('Baseline (0%): PPL {:.2f}'.format(ppl_baseline))
print('Pruning cru 10% (sem fine-tune): PPL 6.94')
print()

# Best result: lowest post-fine-tune perplexity (the baseline row competes too)
best = min(results, key=lambda x: x['ppl_after_ft'])
best_ppl = best['ppl_after_ft']
print('MELHOR: {}% esparsidade -> PPL {:.2f} ({:+.1f}% vs baseline)'.format(
    best['sparsity'], best_ppl, 100 * (best_ppl / ppl_baseline - 1)))

# Highest sparsity that is still no worse than the baseline; failing that,
# the highest sparsity that stays within 10% of it.
pruned_levels = [r for r in results if r['sparsity'] > 0]
better = [r for r in pruned_levels if r['ppl_after_ft'] <= ppl_baseline]
within = [r for r in pruned_levels if r['ppl_after_ft'] <= ppl_baseline * 1.1]
if better:
    most_sparse = max(better, key=lambda x: x['sparsity'])
    print('MAX esparsidade melhor que baseline: {}% (PPL {:.2f})'.format(
        most_sparse['sparsity'], most_sparse['ppl_after_ft']))
elif within:
    most_sparse = max(within, key=lambda x: x['sparsity'])
    print('MAX esparsidade dentro de 10% do baseline: {}% (PPL {:.2f})'.format(
        most_sparse['sparsity'], most_sparse['ppl_after_ft']))

In [None]:
# CELL 8: DETAILED TEST OF THE LAST TRAINED LEVEL (the model's current state)

def chat(user_msg, system_msg='You are a helpful AI assistant.', max_tokens=200):
    """Answer *user_msg* through the tokenizer's chat template.

    Builds a system + user conversation, renders it with the generation
    prompt appended, generates greedily (up to *max_tokens* new tokens) with
    the notebook-level ``model`` and ``tokenizer``, and returns only the newly
    generated part, decoded without special tokens.
    """
    conversation = [
        {'role': 'system', 'content': system_msg},
        {'role': 'user', 'content': user_msg},
    ]
    prompt_text = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer(prompt_text, return_tensors='pt').to(model.device)
    prompt_len = encoded['input_ids'].shape[-1]

    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=max_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Slice off the prompt so that only the reply is decoded.
    reply_ids = generated[0][prompt_len:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True)

# Sparsity of the model as it currently stands (the last level trained)
last_level = results[-1]
current_sp = last_level['sparsity']
current_ppl = last_level['ppl_after_ft']

print('=== TESTE DETALHADO: {}% ESPARSIDADE (com fine-tune) ==='.format(current_sp))
print('PPL: {:.2f}'.format(current_ppl))
print()

all_prompts = [
    'The capital of France is',
    'Water boils at',
    'The largest planet in the solar system is',
    'Python is a programming language that',
    'In 1969, humans first',
]

# Plain sentence completion: greedy decoding, up to 30 new tokens per prompt.
print('--- Completar frase ---')
for p in all_prompts:
    inp = tokenizer(p, return_tensors='pt').to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inp,
            max_new_tokens=30,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    completion = tokenizer.decode(out[0], skip_special_tokens=True)
    print('  {}'.format(completion))
print()

# Chat-style questions, answered through chat() and its chat template.
print('--- Chat ---')
perguntas = [
    'What is the capital of France?',
    'Explain what machine learning is in 2 sentences.',
    'Write a Python function that checks if a number is prime.',
]
for p in perguntas:
    print('User:', p)
    answer = chat(p)
    print('BitNet ({}% sparse):'.format(current_sp), answer)
    print('-' * 60)

In [None]:
# CELL 9: SAVE RESULTS

# Settings of this run, stored alongside the measurements.
run_config = {
    'gpu': 'H100',
    'ft_steps_per_level': FT_STEPS,
    'lr': FT_LR,
    'batch_size': BATCH_SIZE,
    'seq_len': SEQ_LEN,
    'optimizer': 'AdamW',
    'weight_decay': 0.01,
    'scheduler': 'CosineAnnealingLR',
    'torch_compile': True,
    'tf32': True,
    'dataset': 'wikitext-2',
}

# One entry per level; 'sample' is exported under the key 'sample_text'.
report = {
    'model': MODEL_ID,
    'date': '2026-02-06',
    'experiment': 'progressive_pruning_finetune_bitnet_2b_h100',
    'config': run_config,
    'baseline_ppl': ppl_baseline,
    'results': [
        {
            'sparsity': r['sparsity'],
            'actual_sparsity': r['actual_sparsity'],
            'ppl_before_ft': r['ppl_before_ft'],
            'ppl_after_ft': r['ppl_after_ft'],
            'ft_loss': r['ft_loss'],
            'sample_text': r['sample'],
            'time_sec': r['time_sec'],
        }
        for r in results
    ],
}

with open('progressive_sparsity_results.json', 'w') as f:
    json.dump(report, f, indent=2)

print('Salvo em progressive_sparsity_results.json')
print()
print('=== RESUMO FINAL ===')
print('Baseline PPL: {:.2f}'.format(ppl_baseline))
print()
for r in results:
    # The baseline row is already printed above.
    if r['sparsity'] != 0:
        delta = 100 * (r['ppl_after_ft'] / ppl_baseline - 1)
        print('  {}% -> PPL {:.2f} antes, {:.2f} depois ({:+.1f}% vs baseline) [{:.0f}s]'.format(
            r['sparsity'], r['ppl_before_ft'], r['ppl_after_ft'], delta, r['time_sec']))

In [None]:
# CELL 10: DOWNLOAD RESULTS
import os

# Show where the report written by the previous cell lives and how big it is.
filepath = os.path.abspath('progressive_sparsity_results.json')
print('Arquivo salvo em: {}'.format(filepath))
print('Tamanho: {:.1f} KB'.format(os.path.getsize(filepath) / 1024))

# google.colab can only be imported inside Colab; anywhere else the
# ImportError branch tells the user to copy the file manually.
try:
    from google.colab import files
    files.download('progressive_sparsity_results.json')
    print('Download Colab iniciado.')
except ImportError:
    print('Nao esta no Colab. Copie o arquivo manualmente do caminho acima.')