# RPT - Validacao de Criticalidade Auto-Organizada no BitNet 2B
#
# Criticalidade Auto-Organizada (SOC) preve que redes neurais otimas operam
# "na borda do caos" com branching ratio ~1.0.
#
# Medimos:
# 1. Branching ratio: ||output|| / ||input|| por camada (~1.0 = critico)
# 2. Perturbation test: ruido no embedding cresce, diminui ou se mantem?
# 3. Lyapunov exponent: taxa de crescimento de perturbacoes (~0 = critico)
#
# So medicao, sem treino. Forward hooks capturam ativacoes.
#
# IMPORTANTE: Rode Cell 1, REINICIE o runtime, rode Cell 1 de novo e continue.

In [None]:
# CELL 1: SETUP
!pip install -q torch torchvision
!pip install -q git+https://github.com/huggingface/transformers.git accelerate datasets

import torch
import time
import json
import math
from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers

# Report library versions so a run can be tied to a specific software stack.
for _label, _version in (('Transformers:', transformers.__version__),
                         ('Torch:', torch.__version__)):
    print(_label, _version)

# Use the GPU when one is visible; later cells read the global `device`.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if _has_cuda else torch.device('cpu')
print('Device:', device)

if device.type == 'cuda':
    # Describe GPU 0: its name and total memory in GB.
    print('GPU:', torch.cuda.get_device_name(0))
    mem = torch.cuda.get_device_properties(0).total_memory
    print(f'VRAM: {mem / 1e9:.1f} GB')

In [None]:
# CELL 2: LOAD MODEL
MODEL_ID = 'microsoft/bitnet-b1.58-2B-4T-bf16'

# Allow TF32 for float32 matmuls / cuDNN kernels on GPUs that support it.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

print('Carregando tokenizer...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    # No pad token defined by the checkpoint: reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

print('Carregando modelo BitNet 2B...')
_load_kwargs = dict(dtype=torch.bfloat16, device_map='auto')
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_load_kwargs)
model.eval()  # measurement only; this notebook never trains

n_params = sum(param.numel() for param in model.parameters())
print(f'Parametros: {n_params:,.0f} ({n_params / 1e9:.1f}B)')

_config = model.config
num_layers = _config.num_hidden_layers
hidden_size = _config.hidden_size
print(f'Camadas: {num_layers}')
print(f'Hidden size: {hidden_size}')

# Locate the transformer blocks and the token-embedding module.
# 1) Try well-known Hugging Face attribute paths for each block index.
# 2) If that does not yield exactly num_layers blocks, fall back to the first
#    nn.ModuleList of length num_layers whose children are all registered
#    (presumably the decoder stack -- verify the printed name).
_modules_dict = dict(model.named_modules())
_LAYER_PATTERNS = ('model.layers.{}', 'transformer.h.{}', 'gpt_neox.layers.{}')

layer_names = []
for i in range(num_layers):
    for pattern in _LAYER_PATTERNS:
        name = pattern.format(i)
        if name in _modules_dict:
            layer_names.append(name)
            break

if len(layer_names) != num_layers:
    for _list_name, _list_mod in _modules_dict.items():
        if isinstance(_list_mod, torch.nn.ModuleList) and len(_list_mod) == num_layers:
            _candidates = ['{}.{}'.format(_list_name, j) for j in range(num_layers)]
            if all(c in _modules_dict for c in _candidates):
                layer_names = _candidates
                print('Blocos (fallback via ModuleList): {}'.format(_list_name))
                break

print('Blocos transformer: {}'.format(len(layer_names)))
if not layer_names:
    print('AVISO: Nenhum bloco transformer encontrado! As analises por camada nao vao funcionar.')

# Embedding: known names first, then the standard HF accessor, then any
# nn.Embedding whose name contains "embed".
embed_module = None
for embed_name in ['model.model.embed_tokens', 'model.embed_tokens', 'transformer.wte', 'gpt_neox.embed_in']:
    if embed_name in _modules_dict:
        embed_module = _modules_dict[embed_name]
        print('Embedding: {}'.format(embed_name))
        break

if embed_module is None and hasattr(model, 'get_input_embeddings'):
    try:
        embed_module = model.get_input_embeddings()
    except NotImplementedError:
        embed_module = None
    if embed_module is not None:
        print('Embedding (get_input_embeddings): {}'.format(type(embed_module).__name__))

if embed_module is None:
    for name, mod in _modules_dict.items():
        if isinstance(mod, torch.nn.Embedding) and 'embed' in name.lower():
            embed_module = mod
            print('Embedding (fallback): {}'.format(name))
            break

if embed_module is None:
    print('AVISO: Embedding nao encontrado! Perturbation test nao vai funcionar.')

def _get_module(name):
    """Look up a submodule by its dotted name in ``_modules_dict``; None when absent."""
    try:
        return _modules_dict[name]
    except KeyError:
        return None


# Show GPU memory held by the loaded model, then signal the cell finished.
if device.type == 'cuda':
    _allocated_gb = torch.cuda.memory_allocated() / 1e9
    print(f'VRAM: {_allocated_gb:.1f} GB')
print('Pronto!')

In [None]:
# CELL 3: DATASET + FUNCOES BASE
from datasets import load_dataset

print('Carregando WikiText-2 validacao...')
val_dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='validation')

SEQ_LEN = 128  # tokens per evaluation chunk for the baseline perplexity

# Concatenate the token ids of every stripped line with at least 20 characters
# (blank and very short lines are dropped) into one continuous stream.
val_ids = []
for example in val_dataset:
    text = example['text'].strip()
    if len(text) < 20:
        continue
    val_ids.extend(tokenizer.encode(text, add_special_tokens=False))

# Cut the stream into non-overlapping windows of exactly SEQ_LEN tokens.
# The "+ 1" keeps the final window when the stream length is an exact multiple
# of SEQ_LEN; a trailing partial window is discarded.
val_chunks = [
    torch.tensor(val_ids[start:start + SEQ_LEN], dtype=torch.long)
    for start in range(0, len(val_ids) - SEQ_LEN + 1, SEQ_LEN)
]

print('Val chunks: {:,}'.format(len(val_chunks)))

# Up to 20 paragraphs longer than 100 characters, truncated to 500 characters,
# used as inputs for the activation and perturbation analyses.
sample_texts = []
for example in val_dataset:
    text = example['text'].strip()
    if len(text) > 100:
        sample_texts.append(text[:500])
    if len(sample_texts) >= 20:
        break
print('Textos para analise: {}'.format(len(sample_texts)))


def compute_ppl(model, val_chunks, max_batches=50):
    """Return the token-level perplexity of ``model`` on the first ``max_batches`` chunks.

    Each chunk is fed as a batch of one with ``labels=input_ids``. The returned
    ``.loss`` is re-weighted by ``len(chunk) - 1`` before averaging, which assumes
    it is the mean over the shifted next-token targets (HF causal-LM convention).

    Args:
        model: callable returning an object with a scalar tensor ``.loss`` when
            called with ``input_ids`` and ``labels``; must expose ``eval()``.
        val_chunks: sequence of 1-D LongTensors of token ids.
        max_batches: maximum number of chunks to evaluate.

    Returns:
        exp(per-token mean loss) as a Python float.

    Raises:
        ValueError: if no evaluated chunk has at least two tokens.
    """
    model.eval()
    # Put inputs where the model lives; fall back to the first parameter's
    # device instead of assuming CUDA is available.
    target_device = getattr(model, 'device', None)
    if target_device is None:
        target_device = next(model.parameters()).device

    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for chunk in val_chunks[:max_batches]:
            n_pred = chunk.shape[0] - 1
            if n_pred < 1:
                # A single token has no next-token target; its loss would be NaN.
                continue
            input_ids = chunk.unsqueeze(0).to(target_device)
            out = model(input_ids=input_ids, labels=input_ids)
            total_loss += out.loss.item() * n_pred
            total_tokens += n_pred

    if total_tokens == 0:
        raise ValueError('compute_ppl: no chunk with at least 2 tokens to evaluate')
    # float64 keeps precision in the exponent; overflow yields inf instead of raising.
    mean_loss = torch.tensor(total_loss / total_tokens, dtype=torch.float64)
    return torch.exp(mean_loss).item()


print('Funcoes definidas.')

In [None]:
# CELL 4: CLASSES DE ANALISE

class ActivationAnalyzer:
    """Record per-layer activation norms with forward hooks.

    On every forward call of a hooked block, one record is appended. The record
    holds three norms, each the L2 norm along the last dimension averaged over
    all leading positions:

    - the block input;
    - the block output;
    - their difference (the "correction").

    It also holds two ratios:

        branching_ratio  = ||output|| / ||input||
        correction_ratio = ||output - input|| / ||input||

    This notebook reads a branching ratio of ~1.0 as "critical" (activations
    neither grow nor shrink through the block).
    """

    def __init__(self):
        self.hooks = []  # handles returned by register_forward_hook
        self.data = {}   # layer name -> list of per-call metric dicts

    def register(self, layer_names):
        """Hook every name in ``layer_names`` that ``_get_module`` resolves.

        Unresolved names are skipped. Hooks are registered with
        ``with_kwargs=True``. This lets the block input be found whether it is
        passed positionally or as the ``hidden_states=`` keyword.
        """
        for name in layer_names:
            module = _get_module(name)
            if module is None:
                continue
            self.data[name] = []
            handle = module.register_forward_hook(self._make_hook(name), with_kwargs=True)
            self.hooks.append(handle)
        print('ActivationAnalyzer: {} hooks'.format(len(self.hooks)))

    def _make_hook(self, name):
        """Return a ``(module, args, kwargs, output)`` hook appending to ``self.data[name]``."""
        data_list = self.data[name]

        def hook_fn(module, args, kwargs, output):
            # Block input: first positional argument, else the hidden_states keyword.
            x_in = args[0] if args else kwargs.get('hidden_states')
            # Blocks may return a tuple whose first element is the hidden state.
            x_out = output[0] if isinstance(output, tuple) and output else output

            if not isinstance(x_in, torch.Tensor) or not isinstance(x_out, torch.Tensor):
                return
            # Ratios only make sense when input and output have the same shape.
            if x_in.shape != x_out.shape:
                return

            x_in_f = x_in.detach().float()
            x_out_f = x_out.detach().float()
            correction = x_out_f - x_in_f

            in_norm = x_in_f.norm(dim=-1).mean().item()
            out_norm = x_out_f.norm(dim=-1).mean().item()
            corr_norm = correction.norm(dim=-1).mean().item()

            data_list.append({
                'in_norm': in_norm,
                'out_norm': out_norm,
                'corr_norm': corr_norm,
                # 1e-8 avoids division by zero for an all-zero input.
                'branching_ratio': out_norm / (in_norm + 1e-8),
                'correction_ratio': corr_norm / (in_norm + 1e-8),
            })
        return hook_fn

    def get_summary(self):
        """Average every metric over the recorded calls, per layer.

        Layers with no records are omitted. Each entry also gets ``n_samples``.
        """
        summary = {}
        for name, data_list in self.data.items():
            if not data_list:
                continue
            avg = {}
            for key in data_list[0]:
                vals = [d[key] for d in data_list]
                avg[key] = sum(vals) / len(vals)
            avg['n_samples'] = len(data_list)
            summary[name] = avg
        return summary

    def remove(self):
        """Detach all hooks and discard the recorded data."""
        for h in self.hooks:
            h.remove()
        self.hooks.clear()
        self.data.clear()


class PerturbationTest:
    """Measure how a small perturbation of the embedding output propagates.

    Two forward passes are run for each text:

    - A clean pass stores every hooked layer's output.
    - A perturbed pass adds Gaussian noise with std ``epsilon`` to the embedding
      output and records, per layer, ``||out_perturbed - out_clean|| / ||out_clean||``.

    This notebook interprets the result as follows:

    - deviation growing across layers -> supercritical (Lyapunov > 0);
    - shrinking -> subcritical (Lyapunov < 0);
    - roughly constant -> critical (Lyapunov ~ 0).
    """

    def __init__(self):
        self.layer_hooks = []    # handles of the per-layer forward hooks
        self.embed_hook = None   # handle of the noise hook, set only during a perturbed pass
        self.clean_outputs = {}  # layer name -> output tensor from the current clean pass
        self.deltas = {}         # layer name -> list of relative deviations (one per text)
        self.mode = 'clean'      # 'clean' or 'perturbed'; selects what the layer hooks do
        self.epsilon = 0.01      # std of the Gaussian noise added to the embedding output

    def register(self, layer_names, epsilon=0.01):
        """Set the noise std and hook every name in ``layer_names`` that ``_get_module`` resolves."""
        self.epsilon = epsilon
        self.deltas = {name: [] for name in layer_names}

        for name in layer_names:
            module = _get_module(name)
            if module is None:
                continue
            handle = module.register_forward_hook(self._make_hook(name))
            self.layer_hooks.append(handle)
        print('PerturbationTest: {} hooks, eps={}'.format(len(self.layer_hooks), epsilon))

    def _make_hook(self, name):
        """Return a hook that stores (clean) or compares against (perturbed) the output of ``name``."""
        test = self

        def hook_fn(module, args, output):
            x_out = output[0] if isinstance(output, tuple) and output else output
            if not isinstance(x_out, torch.Tensor):
                return

            if test.mode == 'clean':
                test.clean_outputs[name] = x_out.detach().clone()
            elif test.mode == 'perturbed':
                clean = test.clean_outputs.get(name)
                if clean is not None and clean.shape == x_out.shape:
                    delta = (x_out.detach().float() - clean.float()).norm().item()
                    clean_norm = clean.float().norm().item()
                    test.deltas.setdefault(name, []).append(delta / (clean_norm + 1e-8))
        return hook_fn

    def _perturb_embed(self, module, args, output):
        """Embedding forward hook: return the output plus N(0, epsilon^2) noise."""
        noise = torch.randn_like(output) * self.epsilon
        return output + noise

    @staticmethod
    def _input_device(model):
        """Device for tokenized inputs: ``model.device``, else the first parameter's device."""
        dev = getattr(model, 'device', None)
        if dev is None:
            dev = next(model.parameters()).device
        return dev

    def run(self, model, tokenizer, embed_module, texts):
        """Run the clean + perturbed pass pair for every text (truncated to 256 tokens)."""
        target_device = self._input_device(model)
        for i, text in enumerate(texts):
            inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=256)
            inputs = {k: v.to(target_device) for k, v in inputs.items()}

            try:
                # Pass 1: clean -- layer hooks store reference outputs.
                self.mode = 'clean'
                with torch.no_grad():
                    _ = model(**inputs)

                # Pass 2: perturbed -- noise injected at the embedding output.
                self.mode = 'perturbed'
                self.embed_hook = embed_module.register_forward_hook(self._perturb_embed)
                with torch.no_grad():
                    _ = model(**inputs)
            finally:
                # Always detach the noise hook; otherwise every later forward pass
                # on this model would silently be perturbed.
                if self.embed_hook is not None:
                    self.embed_hook.remove()
                    self.embed_hook = None
                self.clean_outputs.clear()

            if (i + 1) % 5 == 0:
                print('  Perturbation: {}/{}'.format(i + 1, len(texts)))

    def get_results(self):
        """Mean relative deviation per layer (layers with no measurements are omitted)."""
        results = {}
        for name, delta_list in self.deltas.items():
            if delta_list:
                results[name] = sum(delta_list) / len(delta_list)
        return results

    def remove(self):
        """Detach every hook and discard stored outputs and deltas."""
        for h in self.layer_hooks:
            h.remove()
        self.layer_hooks.clear()
        if self.embed_hook is not None:
            self.embed_hook.remove()
            self.embed_hook = None
        self.clean_outputs.clear()
        self.deltas.clear()


# Quick test: hook every block, run one short forward pass and print the
# branching ratio recorded for the first layer that captured data.
print('Testando hooks...')
test_analyzer = ActivationAnalyzer()
test_analyzer.register(layer_names)

try:
    test_input = tokenizer('Hello world', return_tensors='pt')
    # Send inputs to the model's device instead of assuming CUDA.
    test_input = {k: v.to(model.device) for k, v in test_input.items()}
    with torch.no_grad():
        _ = model(**test_input)
    test_summary = test_analyzer.get_summary()
finally:
    # Detach hooks even if the forward pass fails, so later cells start clean.
    test_analyzer.remove()

if test_summary:
    first_name, first = next(iter(test_summary.items()))
    print('  OK! Branching ratio {}: {:.4f}'.format(first_name, first['branching_ratio']))
else:
    print('  ERRO: Hooks nao capturaram nada!')

print('Classes: ActivationAnalyzer, PerturbationTest')

In [None]:
# CELL 5: MEASURE BRANCHING RATIOS
# Baseline perplexity, then per-block norm ratios averaged over the sample texts.

print('=== BRANCHING RATIO: PROPAGACAO DE ATIVACOES ===')
print('Medindo em {} textos, {} camadas...'.format(len(sample_texts), len(layer_names)))
print()

ppl_baseline = compute_ppl(model, val_chunks)
print('PPL baseline: {:.2f}'.format(ppl_baseline))
print()

analyzer = ActivationAnalyzer()
analyzer.register(layer_names)

try:
    with torch.no_grad():
        for i, text in enumerate(sample_texts):
            inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=256)
            # Place inputs on the model's device instead of assuming CUDA.
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            _ = model(**inputs)
            if (i + 1) % 5 == 0:
                print('  Processados {}/{}'.format(i + 1, len(sample_texts)))
    summary = analyzer.get_summary()
finally:
    # Remove hooks even if a forward pass fails, so later cells are not instrumented.
    analyzer.remove()

print()
print('{:<10} {:>10} {:>10} {:>10} {:>12} {:>12}'.format(
    'Camada', '||input||', '||output||', '||corr||', 'Branch.Ratio', 'Corr.Ratio'))
print('-' * 75)

branching_ratios = []
correction_ratios = []
br_layers = []  # layer index matching each entry of branching_ratios

for name in layer_names:
    if name not in summary:
        continue
    s = summary[name]
    idx = int(name.split('.')[-1])
    br = s['branching_ratio']
    cr = s['correction_ratio']
    branching_ratios.append(br)
    correction_ratios.append(cr)
    br_layers.append(idx)

    print('{:<10} {:>10.1f} {:>10.1f} {:>10.1f} {:>12.4f} {:>12.4f}'.format(
        'Layer {}'.format(idx),
        s['in_norm'], s['out_norm'], s['corr_norm'],
        br, cr))

if not branching_ratios:
    raise RuntimeError('Nenhuma camada capturou ativacoes: verifique layer_names e os hooks.')

print()
avg_br = sum(branching_ratios) / len(branching_ratios)
# Population standard deviation across layers.
std_br = (sum((x - avg_br)**2 for x in branching_ratios) / len(branching_ratios)) ** 0.5

# Report the real layer index of the extremes, not their position in the list.
i_min = min(range(len(branching_ratios)), key=branching_ratios.__getitem__)
i_max = max(range(len(branching_ratios)), key=branching_ratios.__getitem__)

print('=== BRANCHING RATIO ===')
print('Media: {:.4f} (desvio: {:.4f})'.format(avg_br, std_br))
print('Min:   {:.4f} (Layer {})'.format(branching_ratios[i_min], br_layers[i_min]))
print('Max:   {:.4f} (Layer {})'.format(branching_ratios[i_max], br_layers[i_max]))

if abs(avg_br - 1.0) < 0.1:
    print('-> CRITICO: branching ratio ~1.0! Modelo opera na borda do caos.')
elif avg_br < 1.0:
    print('-> SUBCRITICO: ativacoes diminuem camada a camada.')
else:
    print('-> SUPERCRITICO: ativacoes crescem camada a camada.')

In [None]:
# CELL 6: PERTURBATION TEST
# Inject Gaussian noise at the embedding output and track the relative deviation
# at each block. An endpoint Lyapunov estimate summarises the growth per layer.

PERTURB_EPSILON = 0.01  # noise std; cells 7 and 8 report this same value (0.01)

print('=== PERTURBATION TEST ===')
print('Adicionando ruido gaussiano (eps={}) ao embedding'.format(PERTURB_EPSILON))
print('Medindo propagacao do ruido camada a camada...')
print()

if embed_module is None:
    print('ERRO: Embedding nao encontrado. Pulando perturbation test.')
    perturb_results = {}
    lyapunov = None
    amplification = None
else:
    perturb = PerturbationTest()
    perturb.register(layer_names, epsilon=PERTURB_EPSILON)
    try:
        perturb.run(model, tokenizer, embed_module, sample_texts)
        perturb_results = perturb.get_results()
    finally:
        # Detach every hook even on failure so later forward passes are clean.
        perturb.remove()

    print()
    print('{:<10} {:>15} {:>15}'.format('Camada', 'Delta Relativo', 'log(delta)'))
    print('-' * 45)

    deltas = []
    delta_layers = []  # layer index matching each entry of `deltas`
    for name in layer_names:
        if name not in perturb_results:
            continue
        idx = int(name.split('.')[-1])
        d = perturb_results[name]
        deltas.append(d)
        delta_layers.append(idx)
        log_d = math.log(d) if d > 0 else float('-inf')
        print('{:<10} {:>15.6f} {:>15.4f}'.format('Layer {}'.format(idx), d, log_d))

    print()
    if len(deltas) >= 2:
        # Lyapunov exponent: mean log-growth of the relative deviation per layer,
        # between the first and last measured layers. Dividing by the layer-index
        # span stays correct even if some layers produced no measurement.
        n_steps = max(delta_layers[-1] - delta_layers[0], 1)
        lyapunov = (math.log(deltas[-1] + 1e-12) - math.log(deltas[0] + 1e-12)) / n_steps
        amplification = deltas[-1] / (deltas[0] + 1e-12)

        print('=== LYAPUNOV EXPONENT ===')
        print('Lambda: {:.4f}'.format(lyapunov))
        print('Amplificacao total: {:.2f}x (layer {} -> layer {})'.format(
            amplification, delta_layers[0], delta_layers[-1]))

        if abs(lyapunov) < 0.05:
            print('-> CRITICO: perturbacoes se propagam estavelmente (lambda ~0)')
        elif lyapunov < 0:
            print('-> SUBCRITICO: perturbacoes diminuem (lambda < 0)')
        else:
            print('-> SUPERCRITICO: perturbacoes crescem (lambda > 0)')
    else:
        lyapunov = None
        amplification = None
        print('Poucos dados para calcular Lyapunov.')

In [None]:
# CELL 7: RESULTS + CONCLUSION
# Summarise the branching-ratio and perturbation results and classify the regime.
# Reads globals produced by cells 5 and 6: avg_br, std_br, branching_ratios,
# perturb_results, lyapunov, amplification.

print('=' * 80)
print('RESULTADOS: CRITICALIDADE AUTO-ORGANIZADA NO BITNET 2B')
print('=' * 80)
print()

print('--- 1. BRANCHING RATIO ---')
print('Media: {:.4f}'.format(avg_br))
print('Desvio: {:.4f}'.format(std_br))
print('Faixa: {:.4f} - {:.4f}'.format(min(branching_ratios), max(branching_ratios)))

# Layers whose ratio lies within +/-5% of 1.0.
near_one = sum(1 for br in branching_ratios if 0.95 <= br <= 1.05)
print('Camadas com ratio 0.95-1.05: {}/{}'.format(near_one, len(branching_ratios)))
print()

if perturb_results and lyapunov is not None:
    print('--- 2. PERTURBATION TEST ---')
    print('Epsilon: 0.01')  # must match the epsilon used in cell 6
    print('Lyapunov exponent: {:.4f}'.format(lyapunov))
    print('Amplificacao: {:.2f}x'.format(amplification))
    print()

print('--- CONCLUSAO ---')
# Same thresholds as cells 5 and 6: |avg_br - 1| < 0.1 and |lambda| < 0.05.
is_critical_br = abs(avg_br - 1.0) < 0.1
is_critical_lyap = lyapunov is not None and abs(lyapunov) < 0.05

if is_critical_br and is_critical_lyap:
    print('SOC VALIDADO: BitNet 2B opera na criticalidade!')
    print('  - Branching ratio ~1.0 ({:.4f})'.format(avg_br))
    print('  - Perturbacoes estaveis (lambda={:.4f})'.format(lyapunov))
    print('  - Modelo esta na borda do caos como previsto pela teoria RPT')
elif is_critical_br and lyapunov is None:
    # The perturbation test was skipped or had too few layers: do not claim
    # that perturbations are unstable when they were never measured.
    print('SOC PARCIAL: Branching ratio ~1.0, mas Lyapunov nao foi medido')
    print('  - Branching ratio: {:.4f}'.format(avg_br))
elif is_critical_br:
    print('SOC PARCIAL: Branching ratio ~1.0, mas perturbacoes nao estaveis')
    print('  - Branching ratio: {:.4f}'.format(avg_br))
    print('  - Lyapunov: {:.4f} (deveria ser ~0)'.format(lyapunov))
elif is_critical_lyap:
    print('SOC PARCIAL: Perturbacoes estaveis, mas branching ratio != 1.0')
    print('  - Branching ratio: {:.4f} (deveria ser ~1.0)'.format(avg_br))
    print('  - Lyapunov: {:.4f}'.format(lyapunov))
else:
    regime = 'SUBCRITICO' if avg_br < 1.0 else 'SUPERCRITICO'
    print('SOC NAO CONFIRMADO: modelo e {}'.format(regime))
    print('  - Branching ratio: {:.4f}'.format(avg_br))
    if lyapunov is not None:
        print('  - Lyapunov: {:.4f}'.format(lyapunov))

print()
print('=== COMPARACAO COM RESULTADOS ANTERIORES ===')
print('Esparsidade de pesos 10%: PPL melhora -26% (VALIDADO)')
print('Predictive coding: ativacoes nao sao redundantes (NAO CONFIRMADO)')
print('Criticalidade: branching ratio = {:.4f}'.format(avg_br))

In [None]:
# CELL 8: SAVE RESULTS
# Assemble configuration, per-layer metrics and summary statistics produced by
# the earlier cells into one JSON report.

report = {
    'model': MODEL_ID,
    # Date of this run (local time) instead of a fixed string that goes stale.
    'date': time.strftime('%Y-%m-%d'),
    'experiment': 'self_organized_criticality_bitnet_2b',
    'config': {
        'gpu': torch.cuda.get_device_name(0) if device.type == 'cuda' else 'cpu',
        'num_layers': len(layer_names),
        # Chunk length used for the baseline perplexity only.
        'seq_len': SEQ_LEN,
        # Tokenizer truncation applied to the activation / perturbation texts.
        'analysis_max_length': 256,
        'num_sample_texts': len(sample_texts),
        'perturbation_epsilon': 0.01,
        'dataset': 'wikitext-2',
    },
    'baseline_ppl': ppl_baseline,
    'branching_ratios': {},
    'perturbation_deltas': {},
    'summary': {
        'avg_branching_ratio': avg_br,
        'std_branching_ratio': std_br,
        'min_branching_ratio': min(branching_ratios),
        'max_branching_ratio': max(branching_ratios),
        'layers_near_critical': near_one,
    },
}

if lyapunov is not None:
    report['summary']['lyapunov_exponent'] = lyapunov
    report['summary']['total_amplification'] = amplification

for name in layer_names:
    if name in summary:
        report['branching_ratios'][name] = summary[name]
    if name in perturb_results:
        report['perturbation_deltas'][name] = perturb_results[name]

filename = 'criticality_bitnet2b_results.json'
with open(filename, 'w', encoding='utf-8') as f:
    json.dump(report, f, indent=2)

print('Salvo em {}'.format(filename))
print()
print('=== RESUMO FINAL ===')
print('Branching ratio medio: {:.4f}'.format(avg_br))
if lyapunov is not None:
    print('Lyapunov exponent: {:.4f}'.format(lyapunov))
print('PPL baseline: {:.2f}'.format(ppl_baseline))

In [None]:
# CELL 9: DOWNLOAD RESULTADOS
import os

# Report where the JSON report was written and its size; inside Google Colab
# also start a browser download of it.
filename = 'criticality_bitnet2b_results.json'
filepath = os.path.abspath(filename)
print(f'Arquivo salvo em: {filepath}')
size_kb = os.path.getsize(filepath) / 1024
print(f'Tamanho: {size_kb:.1f} KB')

try:
    # google.colab is importable only inside a Colab runtime.
    from google.colab import files
    files.download(filename)
    print('Download Colab iniciado.')
except ImportError:
    print('Nao esta no Colab. Copie o arquivo manualmente do caminho acima.')