# RPT - Validacao de Predictive Coding no BitNet 2B
#
# Hipotese: Residual connections em transformers sao codificacao preditiva.
# Cada camada faz uma pequena correcao (output - input). Se essas correcoes
# sao esparsas e podemos zerar as pequenas sem perder qualidade, validamos
# que o modelo opera como codificacao preditiva.
#
# Diferenca dos testes anteriores:
# - Antes: modificavamos PESOS (magnitude pruning)
# - Agora: modificamos ATIVACOES durante inferencia (residual pruning via hooks)
# - Nao precisa fine-tune! Cada teste e reversivel (remove hook)
#
# V2: Pruning por PERCENTIL (zerar os X% menores de cada camada)
# V1 usava thresholds absolutos que nao estressavam o modelo (magnitudes 28-1199)
#
# IMPORTANTE: Rode Cell 1, REINICIE o runtime, rode Cell 1 de novo e continue.

In [None]:
# CELL 1: SETUP
!pip install -q torch torchvision
!pip install -q git+https://github.com/huggingface/transformers.git accelerate datasets

import torch
import time
import json
import collections
from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers

print('Transformers:', transformers.__version__)
print('Torch:', torch.__version__)

# Prefer the first CUDA device when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)

# On GPU, also report the card name and its total memory in decimal GB.
if device.type == 'cuda':
    print('GPU:', torch.cuda.get_device_name(0))
    mem = torch.cuda.get_device_properties(0).total_memory
    vram_gb = mem / 1e9
    print('VRAM: {:.1f} GB'.format(vram_gb))

In [None]:
# CELL 2: CARREGAR MODELO
MODEL_ID = 'microsoft/bitnet-b1.58-2B-4T-bf16'

# Permit TF32 kernels for whatever float32 matmuls/convolutions remain.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

print('Carregando tokenizer...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Reuse EOS as the padding token when the checkpoint does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print('Carregando modelo BitNet 2B...')
load_options = {'dtype': torch.bfloat16, 'device_map': 'auto'}
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_options)
model.eval()

# Total number of parameter elements across the whole model.
n_params = 0
for weight in model.parameters():
    n_params += weight.numel()
print('Parametros: {:,.0f} ({:.1f}B)'.format(n_params, n_params / 1e9))

# Inspect the architecture as declared by the model config.
print('\n=== ARQUITETURA ===')
num_layers = model.config.num_hidden_layers
hidden_size = model.config.hidden_size
print('Camadas: {}'.format(num_layers))
print('Hidden size: {}'.format(hidden_size))

# Locate the transformer blocks by trying a few common module-name patterns;
# for each layer index the first pattern that exists wins.
layer_patterns = ('model.layers.{}', 'transformer.h.{}', 'gpt_neox.layers.{}')
modules_dict = dict(model.named_modules())
layer_names = []
for layer_idx in range(num_layers):
    candidates = (template.format(layer_idx) for template in layer_patterns)
    found = next((cand for cand in candidates if cand in modules_dict), None)
    if found is not None:
        layer_names.append(found)

print('Blocos transformer encontrados: {}'.format(len(layer_names)))
if not layer_names:
    # Nothing matched: list the depth-2 module names to help pick a pattern.
    print('  ERRO: Nenhum bloco encontrado! Listando modulos de nivel 2:')
    for name in (n for n in modules_dict if n.count('.') == 2):
        print('    {}'.format(name))
else:
    print('  Primeiro: {}'.format(layer_names[0]))
    print('  Ultimo: {}'.format(layer_names[-1]))
    # Show the class of the first block so the hook target can be checked.
    test_module = modules_dict[layer_names[0]]
    print('  Tipo do bloco: {}'.format(type(test_module).__name__))

if device.type == 'cuda':
    vram_used_gb = torch.cuda.memory_allocated() / 1e9
    print('VRAM usada: {:.1f} GB'.format(vram_used_gb))
print('Pronto!')

In [None]:
# CELL 3: DATASET + FUNCOES BASE
from datasets import load_dataset

print('Carregando WikiText-2 validacao...')
val_dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='validation')

# Length, in tokens, of every evaluation chunk.
SEQ_LEN = 128

# Concatenate the token ids of every non-trivial line into one flat stream.
# Lines shorter than 20 characters after stripping are skipped.
val_ids = []
for example in val_dataset:
    text = example['text'].strip()
    if len(text) < 20:
        continue
    val_ids.extend(tokenizer.encode(text, add_special_tokens=False))

# Cut the stream into non-overlapping chunks of exactly SEQ_LEN tokens.
# The "+ 1" on the stop keeps the final full chunk when the stream length is
# an exact multiple of SEQ_LEN; a trailing partial chunk is still discarded.
val_chunks = [
    torch.tensor(val_ids[start:start + SEQ_LEN], dtype=torch.long)
    for start in range(0, len(val_ids) - SEQ_LEN + 1, SEQ_LEN)
]

print('Val chunks: {:,}'.format(len(val_chunks)))

# Texts used for the activation analysis: the first 20 validation entries
# longer than 100 characters (after stripping), each cut to 500 characters.
sample_texts = []
for row in val_dataset:
    stripped = row['text'].strip()
    if len(stripped) <= 100:
        continue
    sample_texts.append(stripped[:500])
    if len(sample_texts) >= 20:
        break
print('Textos para analise: {}'.format(len(sample_texts)))

# Prompts used for the qualitative generation checks.
TEST_PROMPTS = [
    'The capital of France is',
    'Water boils at',
    'The largest planet in the solar system is',
]


def compute_ppl(model, val_chunks, max_batches=50):
    """Return the perplexity of ``model`` over the first ``max_batches`` chunks.

    Each chunk is scored on its own (batch size 1) with the chunk used as both
    input and labels. The per-chunk loss is weighted by ``len(chunk) - 1``,
    which assumes ``out.loss`` is the mean over the chunk's next-token
    predictions (the usual causal-LM convention -- confirm for other models).

    Args:
        model: Callable as ``model(input_ids=..., labels=...)`` returning an
            object with a scalar ``loss`` tensor; must provide ``eval()``.
        val_chunks: Sequence of 1-D ``torch.long`` tensors of token ids.
        max_batches: Maximum number of chunks taken from the start.

    Returns:
        ``exp`` of the token-weighted mean loss, as a Python float.

    Raises:
        ValueError: If no chunk contributes at least one predicted token.
    """
    model.eval()

    # Resolve the target device once. Models without a ``device`` attribute
    # fall back to wherever their first parameter lives, so CPU-only runs work.
    target_device = getattr(model, 'device', None)
    if target_device is None:
        target_device = next(model.parameters()).device

    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for chunk in val_chunks[:max_batches]:
            n_predicted = chunk.shape[0] - 1
            if n_predicted < 1:
                # A 1-token chunk has nothing to predict; skip it.
                continue
            input_ids = chunk.unsqueeze(0).to(target_device)
            out = model(input_ids=input_ids, labels=input_ids)
            total_loss += out.loss.item() * n_predicted
            total_tokens += n_predicted

    if total_tokens == 0:
        raise ValueError('compute_ppl: no tokens to score (empty val_chunks?)')
    return torch.exp(torch.tensor(total_loss / total_tokens)).item()


def test_gen(model, tokenizer, prompts):
    model.eval()
    results = []
    for p in prompts:
        inp = tokenizer(p, return_tensors='pt').to(model.device if hasattr(model, 'device') else 'cuda')
        with torch.no_grad():
            out = model.generate(
                **inp, max_new_tokens=25,
                do_sample=False, pad_token_id=tokenizer.eos_token_id)
        results.append(tokenizer.decode(out[0], skip_special_tokens=True))
    return results


# Status message printed once the helper definitions above have executed.
print('Funcoes definidas.')

In [None]:
# CELL 4: CAPTURAR ATIVACOES (FORWARD HOOKS)

# Name -> module lookup table, built a single time from the loaded model.
_modules_dict = {mod_name: mod for mod_name, mod in model.named_modules()}


def _get_module(name):
    """Return the cached module registered under ``name``, or None if absent."""
    try:
        return _modules_dict[name]
    except KeyError:
        return None


class ResidualCapture:
    """Record magnitude statistics of each hooked layer's residual update.

    A forward hook on every registered module computes ``output - input``
    (the correction that block adds on top of its input) and stores summary
    statistics of its absolute value, one record per forward call. The
    predictive-coding hypothesis under test is that this correction is sparse.

    Statistics are computed in float32 regardless of the activation dtype, so
    low-precision activations (e.g. bfloat16) do not round the reported values.

    Attributes are plain public fields:
        hooks: handles returned by ``register_forward_hook``.
        stats: ``{layer_name: [record, ...]}`` with one dict per forward call.
    """

    # (record key, absolute threshold) pairs reported as "% of entries below".
    _BELOW_THRESHOLDS = (
        ('pct_below_001', 0.001),
        ('pct_below_01', 0.01),
        ('pct_below_05', 0.05),
        ('pct_below_1', 0.1),
    )

    def __init__(self):
        self.hooks = []
        self.stats = {}

    def register(self, layer_names):
        """Attach a capture hook to every named layer that can be resolved.

        Names that ``_get_module`` cannot resolve are reported and skipped.
        Registering a name again starts a fresh record list for it.
        """
        for layer_name in layer_names:
            module = _get_module(layer_name)
            if module is None:
                print('  AVISO: modulo {} nao encontrado'.format(layer_name))
                continue
            self.stats[layer_name] = []
            handle = module.register_forward_hook(self._make_hook(layer_name))
            self.hooks.append(handle)
        print('Hooks registrados: {}'.format(len(self.hooks)))

    def _make_hook(self, layer_name):
        """Build the forward hook that appends records for ``layer_name``."""
        # setdefault keeps the list created by register() and also lets the
        # hook be built on its own without a KeyError.
        stats_list = self.stats.setdefault(layer_name, [])
        thresholds = self._BELOW_THRESHOLDS

        def hook_fn(module, inputs, output):
            x_in = inputs[0] if isinstance(inputs, tuple) else inputs
            x_out = output[0] if isinstance(output, tuple) else output

            # Only tensor pairs of identical shape define a residual update.
            if not isinstance(x_in, torch.Tensor) or not isinstance(x_out, torch.Tensor):
                return
            if x_in.shape != x_out.shape:
                return

            # Upcast before subtracting so neither the difference nor the
            # reductions below are rounded to the activation dtype.
            r_abs = (x_out.detach().float() - x_in.detach().float()).abs()

            record = {
                'mean': r_abs.mean().item(),
                'std': r_abs.std().item(),
                'max': r_abs.max().item(),
                'pct_zero': (r_abs == 0).float().mean().item() * 100,
            }
            for key, limit in thresholds:
                record[key] = (r_abs < limit).float().mean().item() * 100
            stats_list.append(record)

        return hook_fn

    def get_summary(self):
        """Average the recorded statistics per layer.

        Returns:
            ``{layer_name: {stat: mean over records, 'n_samples': count}}``.
            Layers without any record are omitted.
        """
        summary = {}
        for layer_name, stats_list in self.stats.items():
            if not stats_list:
                continue
            count = len(stats_list)
            avg = {
                key: sum(record[key] for record in stats_list) / count
                for key in stats_list[0]
            }
            avg['n_samples'] = count
            summary[layer_name] = avg
        return summary

    def remove(self):
        """Detach all hooks and discard the recorded statistics.

        Call ``get_summary()`` first: the records are cleared here as well.
        """
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()
        self.stats.clear()


class ResidualPercentilePruner:
    """Zero the smallest residual updates of each hooked layer at inference.

    Instead of a fixed absolute threshold, every forward call of every hooked
    layer computes its own threshold: the ``percentile``-th percentile of
    ``|output - input|``. Entries below it have their update dropped (the
    layer output is replaced by the layer input there); the rest keep the
    original output unchanged.

    Example: ``percentile=70`` drops the 70% smallest updates and keeps the
    30% largest. Ties at the threshold are kept, so the pruned fraction can be
    lower than requested; ``get_pruned_pct()`` reports the actual value.

    Attributes are plain public fields:
        hooks: handles returned by ``register_forward_hook``.
        total_activations: number of entries seen since the last reset.
        total_pruned: number of entries whose update was dropped.
    """

    def __init__(self):
        self.hooks = []
        self.total_activations = 0
        self.total_pruned = 0

    def apply(self, layer_names, percentile):
        """Install pruning hooks on every named layer that can be resolved.

        Hooks left over from an earlier ``apply()`` on this instance are
        removed first so pruning is never stacked, and the counters restart
        from zero. Names that ``_get_module`` cannot resolve are skipped.

        Raises:
            ValueError: If ``percentile`` is outside ``[0, 100]``.
        """
        # Build (and thereby validate) the hook before touching any state.
        hook_fn = self._make_prune_hook(percentile)
        self.remove()
        self.reset_counters()

        for layer_name in layer_names:
            module = _get_module(layer_name)
            if module is None:
                continue
            # The closure only uses shared pruner state, so one instance can
            # serve every layer.
            self.hooks.append(module.register_forward_hook(hook_fn))

    def _make_prune_hook(self, percentile):
        """Return a forward hook pruning at ``percentile`` (0-100).

        Raises:
            ValueError: If ``percentile`` is outside ``[0, 100]``.
        """
        if not 0 <= percentile <= 100:
            raise ValueError(
                'percentile must be within [0, 100], got {!r}'.format(percentile))
        pruner = self
        q = percentile / 100.0

        def hook_fn(module, inputs, output):
            x_in = inputs[0] if isinstance(inputs, tuple) else inputs

            if isinstance(output, tuple):
                x_out, rest = output[0], output[1:]
            else:
                x_out, rest = output, None

            # Only tensor pairs of identical shape define a residual update;
            # anything else passes through untouched.
            if not isinstance(x_in, torch.Tensor) or not isinstance(x_out, torch.Tensor):
                return output
            if x_in.shape != x_out.shape:
                return output

            # Magnitudes in float32: torch.quantile needs a float input and
            # the ranking should not be rounded to the activation dtype.
            r_abs = (x_out.float() - x_in.float()).abs()

            # Dynamic threshold: percentile of THIS layer in THIS forward.
            threshold = torch.quantile(r_abs.flatten(), q)
            keep = r_abs >= threshold

            n_total = keep.numel()
            pruner.total_activations += n_total
            pruner.total_pruned += n_total - int(keep.sum().item())

            # Select instead of recomputing ``x_in + residual * mask``: kept
            # entries are exactly the original output and pruned entries are
            # exactly the input, with no low-precision round trip.
            new_out = torch.where(keep, x_out, x_in)

            if rest is not None:
                return (new_out,) + rest
            return new_out

        return hook_fn

    def get_pruned_pct(self):
        """Return the percentage of entries pruned since the last reset."""
        if self.total_activations == 0:
            return 0.0
        return 100.0 * self.total_pruned / self.total_activations

    def reset_counters(self):
        """Restart the activation/pruned counters from zero."""
        self.total_activations = 0
        self.total_pruned = 0

    def remove(self):
        """Detach every installed hook; the counters are left untouched."""
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()


# Smoke test: confirm the hooks fire on a single forward pass.
print('Testando hooks...')
capture_test = ResidualCapture()
capture_test.register(layer_names)

try:
    test_input = tokenizer('Hello world', return_tensors='pt')
    # Send the inputs to the device the model is on rather than assuming
    # CUDA, so the notebook also runs with the CPU fallback chosen in Cell 1.
    test_input = {k: v.to(model.device) for k, v in test_input.items()}
    with torch.no_grad():
        _ = model(**test_input)
    test_summary = capture_test.get_summary()
finally:
    # Always detach the hooks, even if the forward pass fails.
    capture_test.remove()

if test_summary:
    first_layer = list(test_summary.keys())[0]
    print('  OK! {} camadas capturadas. Layer 0 mean residual: {:.4f}'.format(
        len(test_summary), test_summary[first_layer]['mean']))
else:
    print('  ERRO: Hooks nao capturaram nada! Verificar nomes das camadas.')

print('Classes definidas: ResidualCapture, ResidualPercentilePruner')

In [None]:
# CELL 5: BASELINE - ESPARSIDADE NATURAL DAS ATIVACOES

print('=== BASELINE: ESPARSIDADE NATURAL DAS ATIVACOES ===')
print('Capturando residual updates de {} camadas em {} textos...'.format(
    len(layer_names), len(sample_texts)))
print()

# Reference perplexity of the unmodified model, measured before any hook
# is attached.
ppl_baseline = compute_ppl(model, val_chunks)
print('PPL baseline: {:.2f}'.format(ppl_baseline))
print()

capture = ResidualCapture()
capture.register(layer_names)

try:
    with torch.no_grad():
        for i, text in enumerate(sample_texts):
            inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=256)
            # Send the inputs to the device the model is on rather than
            # assuming CUDA, so the CPU fallback from Cell 1 also works.
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            _ = model(**inputs)
            if (i + 1) % 5 == 0:
                print('  Processados {}/{} textos'.format(i + 1, len(sample_texts)))

    # Must be read before remove(), which also clears the recorded stats.
    summary = capture.get_summary()
finally:
    # Always detach the hooks, even if a forward pass fails midway.
    capture.remove()

if summary:
    # Per-layer table of the captured residual statistics.
    print()
    print('{:<25} {:>8} {:>8} {:>8} {:>8} {:>8}'.format(
        'Camada', 'Mean', '<0.001', '<0.01', '<0.05', '<0.1'))
    print('-' * 75)

    layer_indices, means, below_01 = [], [], []

    for layer_name in layer_names:
        s = summary.get(layer_name)
        if s is None:
            continue
        # The layer index is the last dotted component of the module name.
        idx = int(layer_name.rpartition('.')[2])
        layer_indices.append(idx)
        means.append(s['mean'])
        below_01.append(s['pct_below_01'])
        row_label = 'Layer {}'.format(idx)
        print('{:<25} {:>8.4f} {:>7.1f}% {:>7.1f}% {:>7.1f}% {:>7.1f}%'.format(
            row_label,
            s['mean'],
            s['pct_below_001'],
            s['pct_below_01'],
            s['pct_below_05'],
            s['pct_below_1']))

    print()
    print('=== ANALISE ===')

    # Compare the mean correction size of the shallow half vs the deep half.
    n_layers_seen = len(means)
    if n_layers_seen > 1:
        n_first = n_layers_seen // 2
        n_second = n_layers_seen - n_first
        first_half = sum(means[:n_first]) / n_first
        second_half = sum(means[n_first:]) / n_second
        print('Magnitude media (primeiras {} camadas): {:.4f}'.format(n_first, first_half))
        print('Magnitude media (ultimas {} camadas):   {:.4f}'.format(n_second, second_half))
        if second_half >= first_half:
            print('-> Camadas profundas tem correcoes MAIORES (nao confirma predicao)')
        else:
            print('-> Camadas profundas tem correcoes MENORES (confirma predictive coding)')

    # Average, over layers, of the share of entries with magnitude below 0.01.
    if below_01:
        avg_below_01 = sum(below_01) / len(below_01)
    else:
        avg_below_01 = 0
    print('\nMedia de ativacoes < 0.01: {:.1f}%'.format(avg_below_01))
    if avg_below_01 > 10:
        print('-> Ativacoes sao naturalmente esparsas! Predictive coding plausivel.')
    else:
        print('-> Ativacoes nao sao muito esparsas naturalmente.')
else:
    print('ERRO: Nenhuma ativacao capturada! Hooks nao funcionaram.')

In [None]:
# CELL 6: BASELINE GERACAO

print('=== BASELINE GERACAO ===')
# Reference generations from the unmodified model, one per test prompt.
texts_baseline = test_gen(model, tokenizer, TEST_PROMPTS)
for generated in texts_baseline:
    print('  {}'.format(generated))

if device.type == 'cuda':
    vram_now_gb = torch.cuda.memory_allocated() / 1e9
    print('\nVRAM: {:.1f} GB'.format(vram_now_gb))

In [None]:
# CELL 7: TESTE PROGRESSIVO - PRUNING POR PERCENTIL
#
# Teste anterior (threshold absoluto) mostrou que magnitudes sao enormes (28-1199),
# entao thresholds ate 0.5 so zeraram 0.6% das ativacoes.
#
# Agora: zerar os X% MENORES de cada camada, independente da escala.
# percentile=70 -> zera os 70% menores, mantem os 30% maiores.
# Como ~50% ja sao naturalmente zero, percentile=50 e ~baseline.

# Percentiles to sweep; each run zeroes that share of the smallest updates.
PERCENTILES = [50, 60, 70, 75, 80, 85, 90, 95, 99]

# The first entry is the unpruned baseline so later cells can compare to it.
results = []
results.append({
    'percentile': 0,
    'pct_pruned': 0.0,
    'ppl': ppl_baseline,
    'sample': texts_baseline[0],
    'time_sec': 0,
})

print('=== RESIDUAL PRUNING POR PERCENTIL ===')
print('Percentis: {}'.format(PERCENTILES))
print('percentile=X -> zera os X% menores de cada camada')
print('~50% das ativacoes ja sao naturalmente zero')
print()

for pct in PERCENTILES:
    print('=' * 50)
    print('PERCENTIL: {}% (manter top {}%)'.format(pct, 100 - pct))
    t0 = time.time()

    pruner = ResidualPercentilePruner()
    try:
        # Install the pruning hooks.
        pruner.apply(layer_names, pct)

        # Measure perplexity with pruning active.
        print('  Medindo PPL...')
        ppl = compute_ppl(model, val_chunks)
        pct_pruned = pruner.get_pruned_pct()

        # Restart the counters so generation does not mix into the PPL figure.
        pruner.reset_counters()

        # Generate text with pruning active.
        print('  Gerando texto...')
        texts = test_gen(model, tokenizer, TEST_PROMPTS)
    finally:
        # Always detach the hooks: if a step fails or is interrupted, the
        # model must not stay pruned for the cells that follow.
        pruner.remove()

    dt = time.time() - t0

    result = {
        'percentile': pct,
        'pct_pruned': pct_pruned,
        'ppl': ppl,
        'texts': texts,
        'sample': texts[0],
        'time_sec': dt,
    }
    results.append(result)

    delta = 100 * (ppl / ppl_baseline - 1)
    print('  P{}: {:.1f}% zerado | PPL {:.2f} ({:+.1f}% vs base) | {:.0f}s'.format(
        pct, pct_pruned, ppl, delta, dt))
    for j, t in enumerate(texts):
        print('  [{}] {}'.format(j + 1, t[:80]))
    print()

    # Stop the sweep once perplexity is more than 10x the baseline.
    if ppl > ppl_baseline * 10:
        print('  *** PPL {:.1f}x baseline, parando. ***'.format(ppl / ppl_baseline))
        break

print('Finalizado.')

In [None]:
# CELL 8: RESULTADOS + COMPARACAO

# Shared layout for the header and every row of the results table.
table_row = '{:<12} {:<12} {:<12} {:<12} {}'

print('=' * 80)
print('RESULTADOS: PREDICTIVE CODING (PERCENTILE PRUNING) NO BITNET 2B')
print('=' * 80)
print()
print(table_row.format('Percentil', '% Zerado', 'PPL', 'vs Base', 'Texto'))
print('-' * 80)

for r in results:
    # Percentile 0 is the unpruned baseline row and has no delta.
    if r['percentile'] == 0:
        label, delta_str = '0 (base)', '-'
    else:
        label = 'P{}'.format(r['percentile'])
        delta_str = '{:+.1f}%'.format(100 * (r['ppl'] / ppl_baseline - 1))
    sample = r['sample'][:35].replace('\n', ' ')
    print(table_row.format(
        label,
        '{:.1f}%'.format(r['pct_pruned']),
        '{:.2f}'.format(r['ppl']),
        delta_str,
        sample))

print()
print('=== COMPARACAO: PESO vs ATIVACAO ===')
print()
print('Esparsidade de PESOS (magnitude pruning + fine-tune):')
print('  10% pesos zerados = PPL -40.4% (MELHOR)')
print('  40% pesos zerados = PPL -29.3% (maximo usavel)')
print()
print('Esparsidade de ATIVACOES (residual pruning, SEM fine-tune):')

# Lowest perplexity over all rows, baseline included.
best = min(results, key=lambda row: row['ppl'])
best_delta = 100 * (best['ppl'] / ppl_baseline - 1)
print('  Melhor: P{} -> {:.1f}% zerado, PPL {:.2f} ({:+.1f}%)'.format(
    best['percentile'],
    best['pct_pruned'],
    best['ppl'],
    best_delta))

# "Usable" runs: pruned runs whose perplexity stays below 1.5x the baseline.
usable = [r for r in results if r['ppl'] < ppl_baseline * 1.5 and r['percentile'] > 0]
if usable:
    most_pruned = max(usable, key=lambda row: row['pct_pruned'])
    print('  Max usavel (<50% PPL): P{} -> {:.1f}% zerado, PPL {:.2f}'.format(
        most_pruned['percentile'],
        most_pruned['pct_pruned'],
        most_pruned['ppl']))

# Lowest percentile whose perplexity is more than 10% above the baseline.
degraded = [r for r in results if r['ppl'] > ppl_baseline * 1.1 and r['percentile'] > 0]
if degraded:
    first_bad = min(degraded, key=lambda row: row['percentile'])
    print('  Degrada (>10% PPL) a partir de: P{}'.format(first_bad['percentile']))

print()
print('=== CONCLUSAO ===')
# Share of activations zeroed by the most aggressive usable run (0 if none).
tolerated_pct = most_pruned['pct_pruned'] if usable else 0.0
if best['ppl'] < ppl_baseline * 0.99:
    print('PREDICTIVE CODING VALIDADO: zerar correcoes pequenas MELHORA o modelo!')
elif tolerated_pct > 60:
    print('PREDICTIVE CODING VALIDADO: modelo tolera {:.0f}% de ativacoes zeradas'.format(
        tolerated_pct))
    print('Correcoes residuais sao redundantes - confirma codificacao preditiva.')
elif tolerated_pct > 50:
    print('PREDICTIVE CODING PARCIAL: modelo tolera ate {:.0f}% zerado'.format(
        tolerated_pct))
    print('Alguma redundancia nas correcoes, mas limitada.')
else:
    print('PREDICTIVE CODING NAO CONFIRMADO: modelo sensivel a mudancas nas ativacoes.')

In [None]:
# CELL 9: SALVAR RESULTADOS

# Name of the GPU used for the run, or 'cpu' when no CUDA device was selected.
if device.type == 'cuda':
    gpu_label = torch.cuda.get_device_name(0)
else:
    gpu_label = 'cpu'

report = {
    'model': MODEL_ID,
    'date': '2026-02-06',
    'experiment': 'predictive_coding_percentile_pruning_bitnet_2b',
    'config': {
        'gpu': gpu_label,
        'num_layers': len(layer_names),
        'seq_len': SEQ_LEN,
        'num_sample_texts': len(sample_texts),
        'dataset': 'wikitext-2',
        'percentiles_tested': PERCENTILES,
    },
    # Per-layer averages captured on the unmodified model.
    'natural_sparsity': {layer_name: s for layer_name, s in summary.items()},
    'baseline_ppl': ppl_baseline,
    'results': [],
}

# One entry per run; only pruned runs carry the full list of generations.
for r in results:
    entry = {
        'percentile': r['percentile'],
        'pct_pruned': r['pct_pruned'],
        'ppl': r['ppl'],
        'sample_text': r['sample'],
        'time_sec': r['time_sec'],
    }
    if 'texts' in r:
        entry['all_texts'] = r['texts']
    report['results'].append(entry)

filename = 'predictive_coding_percentile_bitnet2b_results.json'
with open(filename, 'w') as f:
    json.dump(report, f, indent=2)

print('Salvo em {}'.format(filename))
print()
print('=== RESUMO FINAL ===')
print('Baseline PPL: {:.2f}'.format(ppl_baseline))
# Skip the baseline row (percentile 0); list only the pruned runs.
for r in (row for row in results if row['percentile'] != 0):
    delta = 100 * (r['ppl'] / ppl_baseline - 1)
    print('  P{}: {:.1f}% zerado, PPL {:.2f} ({:+.1f}%)'.format(
        r['percentile'], r['pct_pruned'], r['ppl'], delta))

In [None]:
# CELL 10: DOWNLOAD RESULTADOS
import os

filename = 'predictive_coding_percentile_bitnet2b_results.json'
filepath = os.path.abspath(filename)
print('Arquivo salvo em: {}'.format(filepath))
print('Tamanho: {:.1f} KB'.format(os.path.getsize(filepath) / 1024))

# Only the import decides whether this is Colab. Keeping the download call
# out of the ``try`` body means an ImportError raised inside it is not
# misreported as "not running in Colab".
try:
    from google.colab import files
except ImportError:
    print('Nao esta no Colab. Copie o arquivo manualmente do caminho acima.')
else:
    files.download(filename)
    print('Download Colab iniciado.')