<a href="https://colab.research.google.com/github/mahb97/Wake2vec/blob/main/Wake2Vec_LLaMA_2_13B.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wake2Vec: LLaMA-2-13B Embeddings fun (4-bit, NC-17)

**What this notebook does**
- Injects a custom *Finnegans Wake* (FW) lexicon into LLaMA-2-13B (4-bit, T4-safe)
- Trains only the input embedding matrix (lm_head hard-tied)
- Uses an aggressive spherical init for new tokens
- Masks gradients to update only new rows, keeping the base space fixed.
- Adds a repulsion regularizer so Wake tokens don’t collapse into a single cluster.
- Optional norm clamp to keep things wild but stable.
- Saves compact embedding-only artifacts and writes geometry receipts:
  - PIP loss (global geometry shift)
  - Isotropy (spectral health)
  - Top-k neighbor overlap (neighborhood reshuffle)

The aim is to push maximum local semantic drift with minimal compute. TinyLlama is cute; this is not.


**Outputs**
- `embedding_only/embed_tokens.pt` — updated embedding weights
- `embedding_only/added_tokens.json` — which tokens were injected (plus `old_vocab` and `num_added`)
- `geometry_report.json` — PIP / isotropy / overlap metrics


In [None]:
%capture
# Wake2Vec: fuck the embeddings edition
# lets see how far that t4 will go on this embedding rampage (PG 18+) — LLaMA ONLY, 4-bit, Embedding-Only

!pip -q install torch==2.4.0 bitsandbytes==0.43.3 transformers==4.45.2 accelerate==0.34.2 datasets==2.20.0 peft==0.13.2 --progress-bar off

import gc
import json
import math
import os
import random

import torch
import torch.nn as nn
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    set_seed,
)

# CONFIG LLaMA
# All run settings are module-level globals read directly by the code below.
SEED = 42
MODEL_NAME = "meta-llama/Llama-2-13b-hf"
# NOTE(review): this is the only path under /mnt/data (everything else lives
# under /content). read_lines() returns [] for a missing file, so a wrong path
# silently injects zero Wake tokens. Confirm where the lexicon is uploaded.
WAKE_LEX_PATH = "/mnt/data/wake_lexicon.txt"
# If this file is missing, read_corpus() falls back to 1000 copies of one sentence.
CORPUS_TXT    = "/content/finnegans_wake.txt"
RUN_DIR = "/content/wake_llama_embs"  # Trainer output, tokenizer, embedding_only/ and geometry_report.json
SEQ_LEN = 1024        # tokens per packed training chunk (pack_causal block size)
MAX_STEPS = 1100      # passed to TrainingArguments(max_steps=...)
LOG_STEPS = 20
SAVE_STEPS = 200      # Trainer checkpoint interval
LR = 8e-4
GRAD_ACCUM = 8        # per-device batch size is 1, so 8 sequences per optimizer update
REPULSION_W = 0.05    # weight of the new-row repulsion term added in EmbOnlyTrainer (0 disables it)
# Both norms below are absolute L2 values. base_radius (std * sqrt(dim)) is only
# computed in the init step further down, so write a number here, not an
# expression that uses it. TARGET_NORM=None (or 0) -> 1.5 * base_radius;
# MAX_ROW_NORM=None -> no clamping of the new rows.
TARGET_NORM = None            # e.g. 1.8 * base_radius
MAX_ROW_NORM = None           # e.g. 2.0 * base_radius
REPORT_SAMPLE = 1500  # rows sampled for the top-k neighbour-overlap metric

os.makedirs(RUN_DIR, exist_ok=True); set_seed(SEED)
# NOTE(review): `device` is not referenced anywhere else in this notebook as
# shown; model placement comes from device_map="auto".
device = "cuda" if torch.cuda.is_available() else "cpu"

# helpers
def read_lines(p, fb=None):
    """Return the stripped, non-blank lines of the UTF-8 text file at ``p``.

    Args:
        p: Path to the file.
        fb: Fallback returned when ``p`` does not exist. A falsy ``fb``
            (including the default ``None``) yields ``[]``.

    Returns:
        A list of strings, one per non-blank line, with surrounding
        whitespace removed.
    """
    if not os.path.exists(p):
        return fb or []
    # Context manager so the handle is closed deterministically (the original
    # left it open until garbage collection).
    with open(p, encoding="utf-8") as fh:
        return [line.strip() for line in fh if line.strip()]

def read_corpus(p):
    """Load the corpus at ``p`` as a ``Dataset`` with one row per non-blank line.

    Every newline-separated line of the UTF-8 file is stripped; blank lines are
    dropped. When ``p`` does not exist, the function returns 1000 copies of the
    opening sentence of *Finnegans Wake* instead, so the pipeline can be
    smoke-tested without the full text.

    Returns:
        ``Dataset`` with a single ``"text"`` column.
    """
    if not os.path.exists(p):
        s = ("riverrun, past Eve and Adam’s, from swerve of shore to bend of bay, "
             "brings us by a commodius vicus of recirculation to Howth Castle and Environs.")
        return Dataset.from_dict({"text":[s]*1000})
    # Context manager: the original left the file handle open.
    with open(p, "r", encoding="utf-8") as fh:
        txt = fh.read()
    paras = [t.strip() for t in txt.split("\n") if t.strip()]
    return Dataset.from_dict({"text": paras})

def pack_causal(ex, tok, block):
    """Concatenate a batch of texts and cut it into full ``block``-token chunks.

    Intended for ``Dataset.map(..., batched=True)``. The texts in ``ex["text"]``
    are joined with blank lines, tokenized without special tokens, and split
    into consecutive non-overlapping chunks of exactly ``block`` ids. A trailing
    partial chunk shorter than ``block`` is dropped.

    Args:
        ex: Batch dict with a ``"text"`` list of strings.
        tok: Callable tokenizer returning a mapping with ``"input_ids"``.
        block: Chunk length in tokens.

    Returns:
        ``{"input_ids": chunks, "labels": copies_of_chunks}``.
    """
    ids = tok("\n\n".join(ex["text"]), add_special_tokens=False)["input_ids"]
    # "+ 1" keeps a final chunk that ends exactly at len(ids); the old bound
    # (len(ids) - block) silently discarded it.
    chunks = [ids[i:i + block] for i in range(0, len(ids) - block + 1, block)]
    # Independent label lists, so later mutation of one never aliases the other.
    return {"input_ids": chunks, "labels": [c[:] for c in chunks]}

def isotropy(M):
    """Return the spectral condition ratio ``s_max / s_min`` of the centered ``M``.

    Singular values come from ``torch.pca_lowrank``, which is randomized and
    uses the global torch RNG. Lower values mean a more isotropic (evenly
    spread) embedding cloud. ``s_min`` is floored at 1e-8.

    The requested rank ``q`` is capped at 128 and at ``rows - 1`` / ``cols - 1``.
    Centering removes one rank from the row side, and ``pca_lowrank`` rejects
    ``q > min(rows, cols)``. The old row-agnostic bound crashed on small
    matrices such as a few hundred or fewer new Wake rows. The input is upcast
    to float32 because fp16 QR is not available on CPU.

    Raises:
        ValueError: if ``M`` has fewer than 2 rows or columns.
    """
    X = M.float()
    X = X - X.mean(0, keepdim=True)
    q = min(128, X.shape[0] - 1, X.shape[1] - 1)
    if q < 1:
        raise ValueError(f"isotropy needs at least a 2x2 matrix, got {tuple(M.shape)}")
    # Already centered above, so skip pca_lowrank's own centering.
    _, s, _ = torch.pca_lowrank(X, q=q, center=False)
    return float(s.max() / s.min().clamp_min(1e-8))

def pip_loss(A, B):
    """Normalized PIP loss ``||A A^T - B B^T||_F / n**2`` between two embeddings.

    ``A`` and ``B`` must have the same number of rows ``n`` (row ``i`` is the
    same token in both). Their widths may differ.

    When ``n`` exceeds the widths, the n×n Gram matrices are never built. The
    code uses the identity
    ``||AA^T - BB^T||_F^2 = ||A^T A||_F^2 + ||B^T B||_F^2 - 2 ||A^T B||_F^2``,
    which needs only d×d products. For a ~32k-row vocabulary the direct form
    needs several multi-GB n×n temporaries on CPU.

    Everything is computed in float64, so fp16 inputs neither overflow nor lose
    small differences to cancellation.
    """
    A = A.double()
    B = B.double()
    n = A.shape[0]
    if n <= max(A.shape[1], B.shape[1]):
        fro = torch.linalg.norm(A @ A.T - B @ B.T)
    else:
        sq = ((A.T @ A).pow(2).sum()
              + (B.T @ B).pow(2).sum()
              - 2 * (A.T @ B).pow(2).sum())
        fro = sq.clamp_min(0).sqrt()  # guard tiny negative rounding residue
    return float(fro / (n ** 2))

def topk_overlap(M1, M2, k=10, sample=1000, chunk=256):
    """Mean top-``k`` cosine-neighbour overlap between two embedding matrices.

    Draws up to ``sample`` row indices with the global ``random`` module. For
    each index ``i`` it takes the ``k`` nearest rows to ``i`` by cosine
    similarity in ``M1`` and in ``M2``, excluding ``i`` itself, and records the
    fraction of neighbours the two lists share. 1.0 means the neighbourhoods
    are unchanged.

    Args:
        M1, M2: Row-aligned matrices with the same number of rows.
        k: Neighbours per row; capped at ``rows - 1`` so ``topk`` cannot fail
            on tiny matrices.
        sample: Maximum number of rows to evaluate.
        chunk: Sampled rows scored per matrix product (memory/speed knob).

    Returns:
        Mean overlap in ``[0, 1]``, or ``nan`` when fewer than two rows exist.
    """
    # Upcast: fp16 matmuls on CPU are slow or unsupported depending on build.
    W1 = M1.float()
    W2 = M2.float()
    W1 = W1 / (W1.norm(dim=1, keepdim=True) + 1e-8)
    W2 = W2 / (W2.norm(dim=1, keepdim=True) + 1e-8)
    vocab = W1.shape[0]
    idxs = random.sample(range(vocab), min(sample, vocab))
    k = min(k, vocab - 1)
    if k < 1 or not idxs:
        return float("nan")
    acc = 0.0
    for start in range(0, len(idxs), chunk):
        rows = idxs[start:start + chunk]
        r1 = torch.tensor(rows, device=W1.device)
        r2 = torch.tensor(rows, device=W2.device)
        # k + 1 so the row itself (normally its own top hit) can be dropped.
        nn1 = torch.topk(W1[r1] @ W1.T, k + 1, dim=1).indices.tolist()
        nn2 = torch.topk(W2[r2] @ W2.T, k + 1, dim=1).indices.tolist()
        for i, c1, c2 in zip(rows, nn1, nn2):
            c1 = [j for j in c1 if j != i][:k]
            c2 = [j for j in c2 if j != i][:k]
            acc += len(set(c1) & set(c2)) / k
    return float(acc / len(idxs))

# LLaMA in 4-bit
# NF4 + double-quantized weights; the 4-bit matmuls compute in fp16.
bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True,
                         bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16)
tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
# Reuse EOS as the pad token when the tokenizer defines none, so the collator can pad.
if tok.pad_token is None: tok.pad_token = tok.eos_token
# NOTE(review): torch_dtype=float16 presumably also leaves the (non-quantized)
# embedding matrix in fp16; that matrix is the tensor trained below.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, quantization_config=bnb,
                                             torch_dtype=torch.float16, device_map="auto")
# Intent: tie lm_head to the input embeddings so training the input matrix also
# moves the output projection.
# NOTE(review): tie_weights() is called before tie_word_embeddings is set to
# True on the next line; if the loaded config has it False, this call likely
# ties nothing. The explicit lm_head assignment after resizing is what
# enforces the tie.
if hasattr(model, "tie_weights"): model.tie_weights()
model.config.tie_word_embeddings = True

# Wake lex
# An entry counts as "missing" when convert_tokens_to_ids maps it to the unk id,
# i.e. it is not already a single token in the vocabulary.
wake = read_lines(WAKE_LEX_PATH)
missing = [t for t in wake if tok.convert_tokens_to_ids(t)==tok.unk_token_id]
# num_added is the count add_tokens reports. NOTE(review): it may be smaller
# than len(missing), e.g. if the lexicon contains duplicates — confirm.
num_added = tok.add_tokens(missing, special_tokens=False)
# Embedding row count before resizing. NOTE(review): everything below assumes
# the new token ids are exactly [old_vocab, old_vocab + num_added), i.e. that
# len(tok) before add_tokens equalled this row count — confirm for this checkpoint.
old_vocab = model.get_input_embeddings().weight.shape[0]
model.resize_token_embeddings(len(tok))
wte = model.get_input_embeddings()
# Make lm_head share wte's weight tensor.
# NOTE(review): if the checkpoint shipped a separate pretrained lm_head, that
# output projection is discarded here and replaced by the input embedding
# matrix — confirm this is intended.
if hasattr(model, "lm_head"): model.lm_head.weight = wte.weight  # ensure tie

# spherical kick init
# Each new row gets a random direction (a Gaussian sample normalized to unit
# length) scaled to the same L2 norm, target_radius.
# base_radius = std * sqrt(dim), where std is taken over every entry of the base
# rows; it approximates the RMS base-row norm when entries are roughly zero-mean.
# The default target (1.5x) places new rows deliberately outside that shell.
with torch.no_grad():
    base = wte.weight[:old_vocab]; dim = base.shape[1]
    std = base.std().item(); base_radius = std * math.sqrt(dim)
    # `or` also treats TARGET_NORM == 0 as unset.
    target_radius = TARGET_NORM or (1.5 * base_radius)
    if num_added>0:
        new = torch.randn((num_added, dim), device=wte.weight.device)
        new = new/(new.norm(dim=1, keepdim=True)+1e-8)*target_radius
        # randn yields the default float dtype; the assignment casts to the weight's dtype.
        wte.weight.data[old_vocab:old_vocab+num_added] = new

# trainables: ONLY new rows
# Freeze every parameter, then re-enable grads on the embedding matrix (which
# lm_head shares). mask_grad zeroes the gradient of every pre-existing row, so
# with weight_decay=0.0 the optimizer should leave the base vocabulary untouched.
for p in model.parameters():
    p.requires_grad = False
wte.weight.requires_grad = True
new_rows = torch.arange(old_vocab, old_vocab + num_added, device=wte.weight.device) if num_added > 0 else None
base_rows = torch.arange(0, old_vocab, device=wte.weight.device)


def mask_grad(grad):
    """Return a copy of ``grad`` with rows ``[0, old_vocab)`` zeroed.

    Registered as a tensor hook on ``wte.weight``, so only the injected Wake
    rows receive a non-zero gradient.

    * The incoming gradient is not edited in place. PyTorch's hook contract
      says a hook must not modify its argument and should return a new
      gradient instead.
    * The mask also applies when no tokens were injected (``num_added == 0``).
      Every row is then a base row and the whole gradient is zeroed, instead
      of silently letting the entire vocabulary train.

    Costs one temporary the size of the embedding gradient.
    """
    if grad is None:
        return grad
    masked = torch.zeros_like(grad)
    masked[old_vocab:] = grad[old_vocab:]
    return masked


wte.weight.register_hook(mask_grad)

def clamp_rows_(emb, max_norm):
    """Cap, in place, the L2 norm of every injected row of ``emb.weight`` at ``max_norm``.

    Only rows ``[old_vocab, old_vocab + num_added)`` are touched. A row already
    within the cap is multiplied by 1. Nothing happens when ``max_norm`` is None
    or when no tokens were injected (``new_rows`` is None).
    """
    nothing_to_do = max_norm is None or new_rows is None
    if nothing_to_do:
        return
    span = slice(old_vocab, old_vocab + num_added)
    injected = emb.weight.data[span]
    # Floor the norms so all-zero rows don't divide by zero.
    lengths = torch.clamp(injected.norm(dim=1, keepdim=True), min=1e-8)
    factor = torch.clamp(max_norm / lengths, max=1.0)
    emb.weight.data[span] = injected * factor

# data
# Line-level dataset -> 95/5 split -> each split packed into SEQ_LEN-token chunks.
ds = read_corpus(CORPUS_TXT)
split = ds.train_test_split(test_size=0.05, seed=SEED)
# batched=True: pack_causal sees one map batch of lines at a time, so a tail
# shorter than SEQ_LEN is dropped per batch, not just once per split.
tok_tr = split["train"].map(lambda e: pack_causal(e, tok, SEQ_LEN), batched=True, remove_columns=["text"])
# NOTE(review): tok_ev is never passed to the Trainer in this notebook as shown
# (evaluation_strategy="no"); it only costs tokenization time.
tok_ev = split["test"].map(lambda e: pack_causal(e, tok, SEQ_LEN), batched=True, remove_columns=["text"])
# Causal-LM collator (mlm=False); presumably rebuilds labels from input_ids per batch.
coll = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)
# Recompute activations in backward to fit 1024-token sequences on a 13B model.
model.gradient_checkpointing_enable()

# receipts (pre)
# CPU snapshot of the embedding matrix before training, for the geometry report.
# The new rows already hold their spherical-kick init at this point, so the
# new-row metrics measure drift from that init, not from "no embedding".
with torch.no_grad():
    pre_full = wte.weight.detach().clone().cpu()
    pre_new  = pre_full[old_vocab:old_vocab+num_added].clone() if num_added>0 else torch.empty(0, pre_full.shape[1])

# trainer with repulsion
class _ClampRowsCallback(TrainerCallback):
    """Re-apply the optional norm cap to the injected rows after each optimizer step.

    The Trainer fires ``on_step_end`` once per optimizer update, after the step
    and before it logs or saves a checkpoint. Every subsequent forward pass and
    every saved checkpoint therefore sees clamped rows.

    The former ``training_step`` override clamped after each micro-batch, i.e.
    before the optimizer update of the accumulation cycle. The final weights,
    and every checkpoint, could therefore exceed ``MAX_ROW_NORM``.
    """

    def on_step_end(self, args, state, control, model=None, **kwargs):
        if model is not None:
            clamp_rows_(model.get_input_embeddings(), MAX_ROW_NORM)


class EmbOnlyTrainer(Trainer):
    """Trainer that adds a repulsion penalty on the injected Wake rows.

    The loss is the model's LM loss plus ``REPULSION_W`` times the mean squared
    off-diagonal cosine similarity of the (mean-centered) new rows. This pushes
    the Wake tokens apart so they don't collapse into one cluster. Rows are
    norm-clamped after every optimizer step via ``_ClampRowsCallback``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_callback(_ClampRowsCallback())

    # **kwargs: newer Trainer releases presumably pass extra keyword arguments
    # (e.g. num_items_in_batch); accepting them keeps this override compatible.
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        out = model(**inputs)
        loss = out.loss
        if num_added > 1 and REPULSION_W > 0:
            # self.model is the unwrapped model, so this works even if `model` is wrapped.
            # fp32 for headroom: the weight itself may be fp16.
            E = self.model.get_input_embeddings().weight[old_vocab:old_vocab + num_added].float()
            E = E - E.mean(0, keepdim=True)
            E = E / (E.norm(dim=1, keepdim=True) + 1e-8)
            sims = E @ E.t()  # (num_added, num_added): O(num_added^2) memory
            eye = torch.eye(E.shape[0], device=E.device, dtype=E.dtype)
            repul = (sims - eye).pow(2).mean()
            loss = loss + REPULSION_W * repul
        return (loss, out) if return_outputs else loss

# Effective batch: per_device_train_batch_size=1 x GRAD_ACCUM micro-batches per update.
# NOTE(review): fp16=True enables AMP loss scaling, but the only trainable tensor
# (the embedding) was presumably loaded in float16 via torch_dtype. PyTorch's
# GradScaler refuses to unscale fp16 gradients ("Attempting to unscale FP16
# gradients"), so this may fail at the first optimizer step; keeping the
# embedding weight in fp32 avoids that. Confirm.
args = TrainingArguments(
    output_dir=RUN_DIR, per_device_train_batch_size=1, gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LR, max_steps=MAX_STEPS, warmup_steps=max(20, MAX_STEPS//20),
    # weight_decay must stay 0.0: decoupled weight decay (AdamW, presumably the
    # default optimizer) would shrink the frozen base rows even though their
    # gradients are zeroed.
    lr_scheduler_type="cosine", weight_decay=0.0, fp16=True, bf16=False,
    # NOTE(review): each checkpoint presumably stores the whole quantized model
    # plus optimizer state for the full embedding matrix; watch Colab disk space.
    logging_steps=LOG_STEPS, save_steps=SAVE_STEPS, save_total_limit=3,
    # NOTE(review): newer transformers releases rename evaluation_strategy to eval_strategy.
    evaluation_strategy="no", report_to="none", dataloader_pin_memory=False,
    # Also enabled via model.gradient_checkpointing_enable() earlier; presumably harmless.
    gradient_checkpointing=True,
)
trainer = EmbOnlyTrainer(model=model, args=args, data_collator=coll, train_dataset=tok_tr)

print(f"[Run] LLaMA emb-only | steps={MAX_STEPS} | seq_len={SEQ_LEN} | accum={GRAD_ACCUM} | LR={LR}")
trainer.train()

# save the updated embedding matrix (ALL rows, not a delta) + injection metadata
save_dir = os.path.join(RUN_DIR, "embedding_only"); os.makedirs(save_dir, exist_ok=True)
torch.save(wte.weight.detach().cpu(), os.path.join(save_dir, "embed_tokens.pt"))
# NOTE(review): "added_tokens" is the full `missing` list, which may include
# entries add_tokens skipped (e.g. duplicates); num_added is the actual count.
with open(os.path.join(save_dir, "added_tokens.json"), "w") as f:
    json.dump({"added_tokens": missing, "old_vocab": old_vocab, "num_added": num_added}, f, indent=2)
# The tokenizer (including the new Wake tokens) goes to RUN_DIR, not embedding_only/.
tok.save_pretrained(RUN_DIR)

# receipts
# Snapshot the trained matrix on CPU and compare it with the pre-training
# snapshot: global PIP shift, neighbour overlap, and isotropy before/after.
# New-row metrics need at least two new rows.
with torch.no_grad():
    post_full = wte.weight.detach().clone().cpu()
    post_new  = post_full[old_vocab:old_vocab+num_added].clone() if num_added>0 else torch.empty(0, post_full.shape[1])

report = {
  "model": MODEL_NAME, "added_tokens": int(num_added), "old_vocab": int(old_vocab),
  "pip_loss_full": pip_loss(pre_full, post_full),
  "topk_overlap_all": topk_overlap(pre_full, post_full, k=10, sample=min(REPORT_SAMPLE, pre_full.shape[0]-1)),
  "isotropy_pre": isotropy(pre_full), "isotropy_post": isotropy(post_full),
  "pip_loss_new_rows": (pip_loss(pre_new, post_new) if num_added>1 else None),
  "isotropy_new_rows": (isotropy(post_new) if num_added>1 else None),
}
# Context manager so the report is flushed and closed deterministically
# (the original passed a bare open() handle to json.dump).
with open(os.path.join(RUN_DIR, "geometry_report.json"), "w", encoding="utf-8") as fh:
    json.dump(report, fh, indent=2)
print("\n=== GEOMETRY REPORT ===")
for k, v in report.items():
    print(f"{k}: {v}")

nothing makes me smile more than "gosh honey you are still working with the Wake are you okay babes?"

and maybe "oh we’re doing violence to vectors today? say less."