# RPT - BitNet b1.58 2B-4T (Microsoft)
#
# Modelo oficial: 2B parametros, treinado do zero com pesos ternarios {-1, 0, +1}.
# Usando versao bf16 (pesos desempacotados, ~4GB VRAM).
#
# IMPORTANTE: Rode Cell 1, depois REINICIE o runtime, depois rode Cell 1 de novo e continue.

In [None]:
# CELL 1: SETUP
# Na primeira vez, o pip install roda. Reinicie o runtime e rode esta cell de novo.
!pip install -q torch torchvision
!pip install -q git+https://github.com/huggingface/transformers.git accelerate

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers

print('Transformers:', transformers.__version__)
print('Torch:', torch.__version__)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
if device.type == 'cuda':
    print('GPU:', torch.cuda.get_device_name(0))
    mem = torch.cuda.get_device_properties(0).total_memory
    print('VRAM: {:.1f} GB'.format(mem / 1e9))
    print('BF16 suportado:', torch.cuda.is_bf16_supported())

In [None]:
# CELL 2: LOAD MODEL
MODEL_ID = 'microsoft/bitnet-b1.58-2B-4T-bf16'

print('Carregando tokenizer...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print('Carregando modelo BitNet 2B...')
# Load the unpacked bf16 weights and let accelerate pick device placement.
# A T4 handles bfloat16 via emulation: it works, it is just slower.
load_kwargs = {'dtype': torch.bfloat16, 'device_map': 'auto'}
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
model.eval()  # inference only: disables training-specific layer behaviour

n_params = sum(tensor.numel() for tensor in model.parameters())
print(f'Parametros: {n_params:,.0f} ({n_params / 1e9:.1f}B)')
if device.type == 'cuda':
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    print(f'VRAM usada: {allocated_gb:.1f} GB')
print('Pronto!')

In [None]:
# CELL 3: BASIC TEST (sentence completion)
TEST_PROMPTS = [
    'The capital of France is',
    'Water boils at',
    'The largest planet in the solar system is',
    'Python is a programming language that',
    'In 1969, humans first',
]

print('=== COMPLETAR FRASE ===')
for prefix in TEST_PROMPTS:
    encoded = tokenizer(prefix, return_tensors='pt').to(model.device)
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            do_sample=False,  # greedy decoding -> deterministic continuation
            max_new_tokens=30,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode the full sequence so the prompt is echoed along with its continuation.
    completion = tokenizer.decode(generated[0], skip_special_tokens=True)
    print(f'  {completion}')
    print()

In [None]:
# CELL 4: CHAT TEST
def chat(user_msg, system_msg='You are a helpful AI assistant.', max_tokens=200):
    """Run one greedy chat turn against the loaded model and return the reply.

    Builds a two-message conversation (system + user), renders it with the
    tokenizer's chat template with the generation prompt appended, generates
    up to ``max_tokens`` new tokens without sampling, and decodes only the
    newly generated tokens.

    Relies on the module-level ``tokenizer`` and ``model`` created in Cell 2.

    Args:
        user_msg: The user's message.
        system_msg: System prompt placed before the user message.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The decoded reply text with special tokens stripped.
    """
    messages = [
        {'role': 'system', 'content': system_msg},
        {'role': 'user', 'content': user_msg},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # HF guidance: text rendered by apply_chat_template(tokenize=False) should
    # be tokenized with add_special_tokens=False, because chat templates emit
    # their own special tokens (e.g. BOS); adding them again duplicates them.
    # NOTE(review): confirm against tokenizer.chat_template for this model.
    inputs = tokenizer(
        prompt, return_tensors='pt', add_special_tokens=False
    ).to(model.device)
    input_len = inputs['input_ids'].shape[-1]
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Drop the prompt tokens so only the model's reply is decoded.
    return tokenizer.decode(out[0][input_len:], skip_special_tokens=True)

print('=== TESTE CHAT ===')
perguntas = [
    'What is the capital of France?',
    'Explain what BitNet is in 2 sentences.',
    'Write a Python function that checks if a number is prime.',
]

# Ask each question in turn and print the reply, separated by a rule line.
separator = '-' * 60
for question in perguntas:
    print('User:', question)
    answer = chat(question)
    print('BitNet:', answer)
    print(separator)

In [None]:
# CELL 5: ARCHITECTURE
import collections

print('=== ARQUITETURA ===')
print('Modelo:', type(model).__name__)

# Tally how many submodules of each class the model contains (named_modules
# also yields the root module itself).
layer_types = collections.Counter(
    type(module).__name__ for _, module in model.named_modules()
)

print('\nCamadas:')
for ltype, count in layer_types.most_common(15):
    print(f'  {count:>4}x {ltype}')

print('\nPrimeiros parametros:')
# List the first 10 parameters; if an 11th exists, print an ellipsis and stop.
for idx, (pname, tensor) in enumerate(model.named_parameters()):
    if idx == 10:
        print('  ...')
        break
    print(f'  {pname[:55]} | {list(tensor.shape)} | dtype={tensor.dtype}')

In [None]:
# CELL 6: INTERACTIVE MODE
# Read questions from stdin and answer them with chat() until the user quits.
print('=== MODO INTERATIVO ===')
print('Digite uma pergunta (ou "sair"):')
print()

while True:
    try:
        user_input = input('Voce > ').strip()
        # An empty line or a quit keyword ends the session.
        if not user_input or user_input.lower() in ('sair', 'exit', 'q'):
            break
        response = chat(user_input)
        print('BitNet > {}'.format(response))
        print()
    except (KeyboardInterrupt, EOFError):
        # Ctrl+C (while typing or during generation) ends the loop, and so does
        # a closed stdin: input() raises EOFError on Ctrl+D or when the
        # notebook runs non-interactively. Neither should crash the cell.
        break

print('Ate mais!')