<a href="https://colab.research.google.com/github/mahb97/Wake2vec/blob/main/Wake2vec_Llama_3_1_8b.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Guardrail: set the TorchAO opt-out flag before `transformers` is imported.

In [5]:
# Guardrail: export the TorchAO opt-out flag before any later cell imports
# transformers.
# NOTE(review): presumably read by transformers at import time — confirm the
# variable name is honoured by the pinned version.
import os

os.environ.update({"TRANSFORMERS_NO_TORCHAO": "1"})

In [None]:
# to match Colab 2025.10 (Torch 2.8.0+cu126)
# Installs fixed versions of the quantisation / training stack so every run
# uses the same library set.
# NOTE(review): triton is pinned independently of the preinstalled torch —
# confirm triton==3.2.0 is compatible with the torch build named above.
!pip -q install -U triton==3.2.0 bitsandbytes==0.44.1 \
    transformers==4.45.2 accelerate==0.34.2 peft==0.13.2 --progress-bar off

# Print the versions that are actually importable after the install.
import torch, bitsandbytes as bnb, triton
print("torch:", torch.__version__, "| cuda:", torch.version.cuda)
print("bnb:", bnb.__version__, "| triton:", triton.__version__)

In [None]:
from getpass import getpass
from huggingface_hub import login

# Read the token without echoing it; it stays in HF_TOKEN because the
# checkpoint download in the next cell passes it explicitly.
HF_TOKEN = getpass("Paste your HF token (hidden): ")

# Log in to the Hub and also hand the token to the git credential helper.
login(
    token=HF_TOKEN,
    add_to_git_credential=True,
)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_ID = "meta-llama/Meta-Llama-3.1-8B"  # or "...-8B-Instruct"

tok = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, use_fast=True)

# dtype and device placement are both delegated to the library ("auto").
_auto_load_kwargs = {
    "torch_dtype": "auto",
    "device_map": "auto",
    "token": HF_TOKEN,
}
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_auto_load_kwargs)

In [None]:
from huggingface_hub import login
from getpass import getpass
# Authenticate first: both downloads below are passed this token.
HF_TOKEN = getpass("HF token (hidden): ")
login(
    token=HF_TOKEN,
    add_to_git_credential=True,
)

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B"

tok = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN, use_fast=True)
# Reuse EOS as the padding token when the tokenizer defines none.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# Half-precision weights, with layer placement decided by device_map="auto".
_fp16_load_kwargs = {
    "torch_dtype": torch.float16,
    "device_map": "auto",
    "token": HF_TOKEN,
}
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **_fp16_load_kwargs)

# Colab/T4 sanity
# NOTE(review): this assigns a plain config attribute after the model has been
# built — confirm it really selects eager attention (it may need to be passed
# to from_pretrained instead).
model.config.attn_implementation = "eager"
# The flag is set before tie_weights() is called, so the call sees it.
model.config.tie_word_embeddings = True
_tie = getattr(model, "tie_weights", None)
if _tie is not None:
    _tie()

print(
    "Loaded:", MODEL_NAME,
    "| dtype:", model.dtype,
    "| pad_token_id:", tok.pad_token_id,
)

In [1]:
import os, math, json, random, torch, torch.nn as nn
from torch.utils.data import Dataset
from transformers import (AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,
                          DataCollatorForLanguageModeling, TrainingArguments, Trainer, set_seed)
from peft import LoraConfig, get_peft_model

# CONFIG
SEED = 42
MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B"

# Input files and output directory.
WAKE_LEX_PATH = "/content/wake_lexicon.txt"
CORPUS_TXT    = "/content/FW_TEXT.txt"
RUN_DIR = "/content/wake_llama_embs"

# Token-window length and step handed to BlockDataset.
SEQ_LEN = 1024
STRIDE  = 1024

# Optimisation schedule (consumed by TrainingArguments and the optimizer).
MAX_STEPS = 1100
LOG_STEPS = 20
SAVE_STEPS = 200
LR = 8e-4
GRAD_ACCUM = 8

# Weight of the repulsion term added to the LM loss for the new rows.
REPULSION_W = 0.05
# When set, replaces the computed init radius of the new embedding rows.
TARGET_NORM = None
# When set, new-row norms are capped at this value after every training step.
MAX_ROW_NORM = None
# Upper bound on the rows sampled for the top-k overlap metric in the report.
REPORT_SAMPLE = 1500

os.makedirs(RUN_DIR, exist_ok=True)
set_seed(SEED)
print(
    "Torch:", torch.__version__,
    "| CUDA:", torch.version.cuda,
    "| GPU:", torch.cuda.get_device_name(0),
)

def read_lines(p, fb=None):
    """Return the non-blank lines of a UTF-8 text file, whitespace-stripped.

    Args:
        p: Path of the file to read.
        fb: Fallback returned when ``p`` does not exist. A falsy fallback
            (including the default ``None``) yields a fresh empty list.

    Returns:
        A list of stripped, non-empty lines, or the fallback.
    """
    if not os.path.exists(p):
        return fb or []
    # The context manager closes the handle as soon as the lines are read
    # instead of leaving that to the garbage collector.
    with open(p, encoding="utf-8") as fh:
        return [stripped for stripped in (line.strip() for line in fh) if stripped]

# Tiny torch-only dataset
class BlockDataset(Dataset):
    """Token windows cut from a text corpus for causal-LM training.

    The corpus is tokenised once (without special tokens) and sliced into
    windows of ``seq_len`` tokens whose starts are ``stride`` apart. Each item
    is a dict with ``input_ids`` and an identical ``labels`` copy.

    Args:
        path: UTF-8 text file. When it does not exist, a short built-in stub
            sentence repeated 2000 times is used instead.
        tok: Callable such that ``tok(text, add_special_tokens=False)``
            returns a mapping with an ``"input_ids"`` list.
        seq_len: Number of tokens per window.
        stride: Distance between consecutive window starts.
    """

    def __init__(self, path, tok, seq_len=1024, stride=1024):
        if os.path.exists(path):
            # Close the corpus file as soon as it has been read.
            with open(path, "r", encoding="utf-8") as fh:
                text = fh.read()
        else:
            stub = ("riverrun, past Eve and Adam’s, from swerve of shore to bend of bay, "
                    "brings us by a commodius vicus of recirculation to Howth Castle and Environs. ")
            text = stub * 2000
        ids = tok(text, add_special_tokens=False)["input_ids"]

        # The "+ 1" makes the last admissible start inclusive: without it a
        # corpus whose length is an exact multiple of the window lost its
        # final full block. max(1, ...) keeps a single start at 0 for corpora
        # shorter than one window.
        stop = max(1, len(ids) - seq_len + 1)
        min_len = seq_len // 2
        blocks = []
        for start in range(0, stop, stride):
            chunk = ids[start:start + seq_len]
            # Only relevant for a corpus shorter than seq_len: keep it if it
            # fills at least half a window.
            if len(chunk) >= min_len:
                blocks.append(chunk)
        self.blocks = blocks

    def __len__(self):
        return len(self.blocks)

    def __getitem__(self, idx):
        ids = torch.tensor(self.blocks[idx], dtype=torch.long)
        # Labels are an independent copy of the inputs.
        return {"input_ids": ids, "labels": ids.clone()}

# 4-bit load
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
# Reuse EOS as the padding token when the tokenizer defines none.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb,
    torch_dtype=torch.float16,
    device_map="auto",
)
# NOTE(review): assigned after the model is built — confirm this actually
# selects eager attention rather than being a no-op.
model.config.attn_implementation = "eager"  # T4 stable
# tie_weights() runs BEFORE the config flag is flipped; the order is kept
# exactly as found. TODO confirm the intended sequence.
if hasattr(model, "tie_weights"):
    model.tie_weights()
model.config.tie_word_embeddings = True

# Minimal PEFT adapter to satisfy “no pure-quant training”
peft_cfg = LoraConfig(
    r=1,
    lora_alpha=1,
    lora_dropout=0.0,
    target_modules=["q_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_cfg)

# Wake lex
wake = read_lines(WAKE_LEX_PATH)

# Keep only the lexicon entries the tokenizer maps to its unknown-token id.
missing = []
for lexeme in wake:
    if tok.convert_tokens_to_ids(lexeme) == tok.unk_token_id:
        missing.append(lexeme)
num_added = tok.add_tokens(missing, special_tokens=False)

# Remember the row count before resizing so the new rows can be addressed as
# [old_vocab, old_vocab + num_added).
old_vocab = model.get_input_embeddings().weight.shape[0]
model.resize_token_embeddings(len(tok))
wte = model.get_input_embeddings()
if hasattr(model, "lm_head"):
    model.lm_head.weight = wte.weight  # re-tie

# Spherical kick
# New rows get a random direction and a common length: 1.5x std*sqrt(dim) of
# the pre-existing rows, unless TARGET_NORM supplies the length directly.
with torch.no_grad():
    base = wte.weight[:old_vocab]
    dim = base.shape[1]
    std = base.std().item()
    base_radius = std * math.sqrt(dim)
    target_radius = TARGET_NORM or (1.5 * base_radius)
    if num_added > 0:
        new = torch.randn((num_added, dim), device=wte.weight.device)
        row_norms = new.norm(dim=1, keepdim=True) + 1e-8
        new = new / row_norms * target_radius
        wte.weight.data[old_vocab:old_vocab + num_added] = new
print(f"[Init] new rows: {num_added} | target L2 ≈ {target_radius:.3f}")

# Trainables: only embeddings
for n, p in model.named_parameters():
    p.requires_grad = False
wte.weight.requires_grad = True

# Index tensors for the added rows (None when nothing was added) and for the
# pre-existing rows, both on the embedding's device.
if num_added > 0:
    new_rows = torch.arange(old_vocab, old_vocab + num_added, device=wte.weight.device)
else:
    new_rows = None
base_rows = torch.arange(0, old_vocab, device=wte.weight.device)
def mask_grad(grad):
    """Zero the gradient rows of the pre-existing vocabulary, in place.

    Returns ``grad`` untouched when it is ``None`` or when no rows were added
    (``new_rows is None``); otherwise rows ``base_rows`` are set to 0 and the
    same tensor is returned.

    NOTE(review): the tensor is modified in place inside a gradient hook —
    confirm this is acceptable for the torch version in use.
    """
    if grad is not None and new_rows is not None:
        grad[base_rows] = 0
    return grad
# Attach mask_grad as a gradient hook on the embedding weight.
wte.weight.register_hook(mask_grad)

def clamp_rows_(emb, max_norm):
    """Shrink, in place, any added embedding row whose L2 norm exceeds ``max_norm``.

    Does nothing when ``max_norm`` is ``None`` or when no rows were added.
    Rows already within the limit are left at their current length (their
    scale factor is capped at 1.0).

    Args:
        emb: Embedding module whose ``weight.data`` is edited.
        max_norm: Largest allowed row norm, or ``None`` to disable clamping.
    """
    if max_norm is None or new_rows is None:
        return
    added = slice(old_vocab, old_vocab + num_added)
    rows = emb.weight.data[added]
    # clamp_min guards the division against zero-length rows.
    norms = rows.norm(dim=1, keepdim=True).clamp_min(1e-8)
    scale = (max_norm / norms).clamp_max(1.0)
    emb.weight.data[added] = rows * scale

# Data
train_ds = BlockDataset(
    CORPUS_TXT,
    tok,
    seq_len=SEQ_LEN,
    stride=STRIDE,
)
# mlm=False: causal-LM collation rather than masked-LM.
coll = DataCollatorForLanguageModeling(mlm=False, tokenizer=tok)
model.gradient_checkpointing_enable()

# Geometry (pre)
@torch.no_grad()
def isotropy(M):
    """Return the spread of the leading singular values of the centred rows.

    The rows of ``M`` are mean-centred, a low-rank PCA with up to 128
    components is computed, and the ratio largest / smallest singular value
    is returned (the smallest is floored at 1e-8). Larger values mean the
    retained spectrum is more uneven.

    Args:
        M: 2-D tensor, one embedding per row.

    Returns:
        The singular-value ratio as a Python float.
    """
    # Run the decomposition in single precision even when the weights were
    # stored in half precision.
    if M.dtype in (torch.float16, torch.bfloat16):
        M = M.float()
    centered = M - M.mean(0, keepdim=True)
    # pca_lowrank rejects q larger than either dimension, so the row count
    # must bound q too (matters for small sets of added rows).
    q = min(128, M.shape[0], M.shape[1] - 1)
    _, s, _ = torch.pca_lowrank(centered, q=q)
    return float(s.max() / s.min().clamp_min(1e-8))
@torch.no_grad()
def pip_loss(A, B):
    """Return ``||A Aᵀ − B Bᵀ||_F / n²`` for two row-aligned matrices.

    The n×n inner-product matrices are never built. The value is obtained
    from the identity

        ||AAᵀ − BBᵀ||_F² = ||AᵀA||_F² + ||BᵀB||_F² − 2·||AᵀB||_F²

    which only needs column-by-column Gram matrices, accumulated over row
    chunks in float64. This keeps memory independent of the number of rows,
    so the metric stays computable for a full vocabulary.

    Args:
        A: 2-D tensor with n rows.
        B: 2-D tensor with the same number of rows as ``A``.

    Returns:
        The normalised Frobenius distance as a Python float.

    Raises:
        ValueError: If ``A`` and ``B`` have different row counts.
    """
    n = A.shape[0]
    if B.shape[0] != n:
        raise ValueError(f"row count mismatch: {n} vs {B.shape[0]}")

    gram_kwargs = {"dtype": torch.float64, "device": A.device}
    gram_aa = torch.zeros(A.shape[1], A.shape[1], **gram_kwargs)
    gram_bb = torch.zeros(B.shape[1], B.shape[1], **gram_kwargs)
    gram_ab = torch.zeros(A.shape[1], B.shape[1], **gram_kwargs)

    chunk = 8192
    for start in range(0, n, chunk):
        # Upcast one slice at a time so the float64 copies stay small.
        a = A[start:start + chunk].double()
        b = B[start:start + chunk].double().to(A.device)
        gram_aa += a.T @ a
        gram_bb += b.T @ b
        gram_ab += a.T @ b

    squared = gram_aa.pow(2).sum() + gram_bb.pow(2).sum() - 2.0 * gram_ab.pow(2).sum()
    # Rounding can push a near-zero difference slightly below zero.
    squared = squared.clamp_min(0.0)
    return float(squared.sqrt() / (n ** 2))
@torch.no_grad()
def topk_overlap(M1, M2, k=10, sample=1000):
    """Average share of cosine nearest neighbours that two embeddings agree on.

    For a random sample of row indices, the ``k`` nearest neighbours of the
    row (itself excluded) are looked up in each matrix by cosine similarity,
    and the fraction common to both is averaged over the sample.

    Args:
        M1: First embedding matrix, one vector per row.
        M2: Second embedding matrix, indexed like ``M1``.
        k: Neighbours compared per row.
        sample: Maximum number of rows drawn (via ``random.sample``).

    Returns:
        Mean overlap in [0, 1] as a Python float.
    """
    def _unit_rows(mat):
        # Row-normalise; the epsilon keeps zero rows from dividing by zero.
        return mat / (mat.norm(dim=1, keepdim=True) + 1e-8)

    def _neighbours(unit, row):
        # Ask for k+1 hits so the row itself can be discarded.
        ranked = torch.topk(unit @ unit[row], k + 1).indices.tolist()
        return set([j for j in ranked if j != row][:k])

    W1 = _unit_rows(M1)
    W2 = _unit_rows(M2)
    vocab = W1.shape[0]
    idxs = random.sample(range(vocab), min(sample, vocab))

    total = 0.0
    for i in idxs:
        total += len(_neighbours(W1, i) & _neighbours(W2, i)) / k
    return float(total / len(idxs))

# CPU copies of the embedding matrix before training, kept for the
# before/after comparison in the geometry report.
with torch.no_grad():
    pre_full = wte.weight.detach().clone().cpu()
    if num_added > 0:
        pre_new = pre_full[old_vocab:old_vocab + num_added].clone()
    else:
        pre_new = torch.empty(0, pre_full.shape[1])

# Trainer
class EmbOnlyTrainer(Trainer):
    """Trainer that optimises only the embedding matrix ``wte.weight``.

    Adds a repulsion penalty between the added rows to the LM loss and
    re-applies the optional row-norm cap after every training step.
    """

    def create_optimizer(self):
        """Create (once) an AdamW over the embedding weight and return it."""
        from torch.optim import AdamW

        if getattr(self, "optimizer", None) is None:
            param_groups = [{"params": [wte.weight], "lr": LR, "weight_decay": 0.0}]
            self.optimizer = AdamW(param_groups, betas=(0.9, 0.999), eps=1e-8)
        return self.optimizer

    def compute_loss(self, model, inputs, return_outputs=False):
        """Return the LM loss plus, when enabled, the new-row repulsion term.

        The penalty is the mean squared entry of (cosine-similarity matrix of
        the centred, unit-length added rows minus the identity), scaled by
        ``REPULSION_W``. It is applied only when more than one row was added
        and the weight is positive.
        """
        out = model(**inputs)
        loss = out.loss
        if num_added and num_added > 1 and REPULSION_W > 0:
            stop = old_vocab + num_added
            E = model.get_input_embeddings().weight[old_vocab:stop]
            E = E - E.mean(0, keepdim=True)
            E = E / (E.norm(dim=1, keepdim=True) + 1e-8)
            sims = E @ E.t()
            off_identity = sims - torch.eye(E.shape[0], device=E.device)
            repul = off_identity.pow(2).mean()
            loss = loss + REPULSION_W * repul
        if return_outputs:
            return (loss, out)
        return loss

    def training_step(self, *args, **kwargs):
        """Run the parent step, then cap new-row norms at ``MAX_ROW_NORM``."""
        out = super().training_step(*args, **kwargs)
        # Uses the module-level ``model``, as the original does.
        clamp_rows_(model.get_input_embeddings(), MAX_ROW_NORM)
        return out

# Training configuration: batch size 1 with gradient accumulation, cosine
# schedule with at least 20 warmup steps, fp16 autocast, periodic checkpoints
# (at most 3 kept), no evaluation and no external reporting.
args = TrainingArguments(
    output_dir=RUN_DIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LR,
    max_steps=MAX_STEPS,
    warmup_steps=max(20, MAX_STEPS//20),
    lr_scheduler_type="cosine",
    weight_decay=0.0,
    fp16=True,
    bf16=False,
    logging_steps=LOG_STEPS,
    save_steps=SAVE_STEPS,
    save_total_limit=3,
    eval_strategy="no",
    report_to="none",
    # Was misspelled "ataloader_pin_memory", which TrainingArguments rejects
    # as an unexpected keyword argument.
    dataloader_pin_memory=False,
    gradient_checkpointing=True,
)

# Wire model, arguments, collator and dataset together, log the run
# parameters, then start training.
trainer = EmbOnlyTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    data_collator=coll,
)
print(f"[Run] emb-only | base={MODEL_NAME} | steps={MAX_STEPS} | seq_len={SEQ_LEN} | accum={GRAD_ACCUM} | LR={LR}")
trainer.train()

# Save
save_dir = os.path.join(RUN_DIR, "embedding_only")
os.makedirs(save_dir, exist_ok=True)

# Full embedding matrix, moved to CPU.
embed_path = os.path.join(save_dir, "embed_tokens.pt")
torch.save(wte.weight.detach().cpu(), embed_path)

# Bookkeeping needed to locate the added rows inside the saved matrix.
added_meta = {"added_tokens": missing, "old_vocab": old_vocab, "num_added": num_added}
with open(os.path.join(save_dir, "added_tokens.json"), "w") as f:
    json.dump(added_meta, f, indent=2)

# The extended tokenizer goes to the run directory itself.
tok.save_pretrained(RUN_DIR)

# Geometry (post)
# CPU copies of the trained embedding matrix, mirroring the pre-training
# snapshot.
with torch.no_grad():
    post_full = wte.weight.detach().clone().cpu()
    post_new  = post_full[old_vocab:old_vocab+num_added].clone() if num_added>0 else torch.empty(0, pre_full.shape[1])

# Before/after metrics over the whole matrix, plus new-row-only metrics when
# more than one row was added (None otherwise).
report = {
  "model": MODEL_NAME, "added_tokens": int(num_added), "old_vocab": int(old_vocab),
  "pip_loss_full": pip_loss(pre_full, post_full),
  "topk_overlap_all": topk_overlap(pre_full, post_full, k=10, sample=min(REPORT_SAMPLE, pre_full.shape[0]-1)),
  "isotropy_pre": isotropy(pre_full), "isotropy_post": isotropy(post_full),
  "pip_loss_new_rows": (pip_loss(pre_new, post_new) if num_added>1 else None),
  "isotropy_new_rows": (isotropy(post_new) if num_added>1 else None),
}
# Write through a context manager so the file is flushed and closed before
# the notebook moves on (the handle was previously never closed).
with open(os.path.join(RUN_DIR, "geometry_report.json"), "w") as report_file:
    json.dump(report, report_file, indent=2)
print("\n=== GEOMETRY REPORT ===")
for k,v in report.items(): print(f"{k}: {v}")


KeyboardInterrupt: 