<a href="https://colab.research.google.com/github/mahb97/Wake2vec/blob/main/try_again_fail_again_fail_better.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wake2vec — Token Injection & Training


Why do I feel so sad without a T4 and so happy with one, lol.

In [None]:
!pip -q install transformers==4.43.3 accelerate peft==0.11.1 datasets umap-learn faiss-cpu matplotlib==3.8.4 sentencepiece

import os, re, math, random, json, unicodedata
from pathlib import Path
import numpy as np
import torch
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, TrainingArguments, Trainer, EarlyStoppingCallback
from peft import LoraConfig, get_peft_model

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)  # torch.manual_seed also seeds CUDA

# Base checkpoint; swap in "distilgpt2" for quick CPU smoke tests.
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 is native only on Ampere-class GPUs (compute capability >= 8.0). Pre-Ampere cards such
# as the Colab T4 (7.5) have no native bf16, so load the weights in float32 there (and on CPU).
USE_BF16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
DTYPE = torch.bfloat16 if USE_BF16 else torch.float32
print("Device:", DEVICE)

In [None]:
tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
# DataCollatorForLanguageModeling pads every batch, which requires a pad token. GPT-2-style
# checkpoints (e.g. distilgpt2 for smoke tests) ship without one, so fall back to EOS.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=DTYPE, device_map="auto")
print("Base vocab size:", tok.vocab_size)

In [None]:
# wake_lexicon.txt — one candidate term per line. Upload it in Colab, otherwise read the default
# /content path. Produces `new_terms`: NFC-normalised, de-duplicated terms the tokenizer lacks.
lex_path = None
try:
    from google.colab import files
    uploaded = files.upload()
    lex_path = list(uploaded.keys())[0]
except Exception:
    # Not running in Colab (import fails) or nothing uploaded (IndexError): use the default path.
    lex_path = "/content/wake_lexicon.txt"

raw = []
p = Path(lex_path)
if p.exists():
    raw = p.read_text(encoding="utf-8", errors="ignore").splitlines()
else:
    print("WARNING: wake_lexicon.txt not found; provide a path or upload.")

seen = set(); new_terms = []
existing = tok.get_vocab()
for w in raw:
    t = unicodedata.normalize("NFC", w.strip())
    # SentencePiece vocabularies store word-initial pieces as "▁word". A bare term whose "▁" form
    # already exists is a known word; adding it as a new token would hijack its tokenisation.
    if t and (t not in seen) and (t not in existing) and (("▁" + t) not in existing):
        seen.add(t); new_terms.append(t)

print(f"Loaded {len(new_terms)} NEW tokens.")
print('Sample:', new_terms[:20])

In [None]:
# Register the Wake terms as ordinary (non-special) tokens and grow the embedding matrix to match.
added = tok.add_tokens(new_terms, special_tokens=False)
if added != len(new_terms):
    print(f"Note: only {added}/{len(new_terms)} newly added (collisions or tokenizer rules).")
model.resize_token_embeddings(len(tok))
# tok.vocab_size reports only the base vocabulary; len(tok) includes the added tokens.
print("New vocab size:", len(tok))

In [None]:
# come to clean the mess: add start-of-word variants for SentencePiece-style tokenizers and tie output head
START = "▁"
# HF tokenizers construct a new dict on every get_vocab() call; snapshot it once instead of
# rebuilding it for every candidate term.
current_vocab = tok.get_vocab()
prefixed = [START + t for t in new_terms if not t.startswith(START)]
prefixed = [t for t in prefixed if t not in current_vocab]
added2 = tok.add_tokens(prefixed, special_tokens=False)
model.resize_token_embeddings(len(tok))
# tok.vocab_size excludes added tokens; len(tok) is the real vocabulary size.
print(f"Added {added2} start-of-word variants. Vocab: {len(tok)}")

# Tie lm_head to input embeddings so generations reflect updated embeddings.
# NOTE(review): if the checkpoint's lm_head is not already tied (see
# model.config.tie_word_embeddings), this replaces the pretrained output projection with the
# input embedding matrix — confirm that trade-off is intended.
if model.get_output_embeddings() is not None:
    model.get_output_embeddings().weight = model.get_input_embeddings().weight
    print("Tied lm_head to input embeddings.")

In [None]:
emb = model.get_input_embeddings()
# `tok` already contains the Wake terms as single tokens, so tokenising a term with it just
# returns the term's own freshly resized row. A pristine base tokenizer instead yields the
# pre-existing sub-word pieces the new vector should be built from.
base_tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
BASE_VOCAB_LEN = len(base_tok)  # ids >= this were added by this notebook

def init_vec(term: str, noise=0.02, tokenizer=None):
    """Initial vector for a new token: mean of its base sub-word embeddings plus Gaussian noise.

    Args:
        term: token string; a single leading "▁" start-of-word marker is dropped before tokenising.
        noise: scale of the additive standard-normal noise.
        tokenizer: tokenizer that splits `term` into pre-existing pieces (default: `base_tok`).

    Returns:
        1-D tensor of the embedding width, same dtype/device as `emb.weight`. Falls back to a
        standard-normal vector (plus noise) when the term tokenises to nothing.
    """
    tokenizer = base_tok if tokenizer is None else tokenizer
    text = term[1:] if term.startswith("▁") else term
    ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    if ids:
        base = emb.weight[ids].mean(0, keepdim=True)
    else:
        base = torch.randn_like(emb.weight[0:1])
    return (base + noise * torch.randn_like(base)).squeeze(0)

with torch.no_grad():
    targets = new_terms + [("▁"+t) for t in new_terms if not t.startswith("▁")]
    for t in targets:
        idx = tok.convert_tokens_to_ids(t)
        # Only overwrite rows this notebook added. A "▁" variant may already be a base token
        # (add_tokens skipped it), and unknown strings map to the unk id (or None).
        if idx is None or idx == tok.unk_token_id or idx < BASE_VOCAB_LEN:
            continue
        emb.weight[idx].copy_(init_vec(t))
print("Initialised bare + ▁ variants.")

In [None]:
# Freeze almost everything, train embeddings (optional tiny LoRA for attention projections)
for param in model.parameters():
    param.requires_grad = False
model.get_input_embeddings().weight.requires_grad_(True)

USE_LORA = True  # if False isolates pure embedding drift
if USE_LORA:
    lora_cfg = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05,
                          target_modules=["q_proj","k_proj","v_proj","o_proj"],
                          bias="none", task_type="CAUSAL_LM")
    model = get_peft_model(model, lora_cfg)
    # get_peft_model marks only the LoRA weights as trainable, which re-freezes the embedding
    # matrix enabled above; switch it back on so the new Wake rows actually train.
    model.get_input_embeddings().weight.requires_grad_(True)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Trainable params:", trainable)

In [None]:
# build Wake-saturated paragraph windows, split into train/val
fw_path = None
try:
    from google.colab import files
    uploaded = files.upload()
    fw_path = list(uploaded.keys())[0]
except Exception:
    # Not running in Colab (import fails) or nothing uploaded (IndexError): use the default path.
    fw_path = "/content/finnegans_wake.txt"

fw = Path(fw_path).read_text(encoding="utf-8", errors="ignore")
fw = unicodedata.normalize("NFC", fw).replace("\r\n","\n").replace("\r","\n")

import re, random
paras = [para.strip() for para in re.split(r"\n\s*\n", fw) if para.strip()]
# Lower-cased lexicon terms used as substring probes; entries longer than 40 chars are skipped.
probe = {t.lower() for t in new_terms if len(t) <= 40}

# A hit is any paragraph containing a probe term. Each hit becomes a window of RADIUS
# paragraphs on either side (slicing clips at the ends of the text).
hits = [i for i, para in enumerate(paras) if any(t in para.lower() for t in probe)]
RADIUS = 2
windows = [" ".join(paras[max(0, i - RADIUS):i + RADIUS + 1]) for i in hits]
windows = list(dict.fromkeys(windows))  # de-duplicate, keeping first-seen order
MIN_WINDOWS = 500
if len(windows) < MIN_WINDOWS:
    # Too few Wake-saturated windows: pad with random whole paragraphs.
    windows += random.sample(paras, k=min(2000, len(paras)))
    windows = list(dict.fromkeys(windows))
# Shuffle AFTER padding. Shuffling first and then appending left every padded paragraph at the
# tail, which concentrated them in the validation split.
random.shuffle(windows)

split = int(0.9 * len(windows))
train_texts, val_texts = windows[:split], windows[split:]
if not train_texts or not val_texts:
    raise ValueError(f"Need at least one train and one validation text; got {len(windows)} windows.")

def tok_fn(batch):
    """Tokenise the `text` column of a batched `datasets.map` call, truncating to 512 tokens."""
    texts = batch["text"]
    encoded = tok(texts, truncation=True, max_length=512)
    return encoded

# Tokenise each split; the raw "text" column is dropped so only tokenizer outputs remain.
_split_texts = {"train": train_texts, "validation": val_texts}
ds = DatasetDict({
    split_name: Dataset.from_dict({"text": texts}).map(tok_fn, batched=True, remove_columns=["text"])
    for split_name, texts in _split_texts.items()
})
# Causal-LM collator (mlm=False): pads each batch and builds labels from input_ids.
collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)
print("Train/Val sizes:", len(ds["train"]), len(ds["validation"]))

In [None]:
# Mixed precision: bf16 autocast only on GPUs with native bf16 (compute capability >= 8.0).
# Elsewhere (e.g. a T4) use fp16 autocast when the weights are float32; otherwise train without AMP.
_has_cuda = torch.cuda.is_available()
use_bf16 = _has_cuda and torch.cuda.get_device_capability()[0] >= 8
use_fp16 = (_has_cuda and not use_bf16
            and model.get_input_embeddings().weight.dtype == torch.float32)

args = TrainingArguments(
    output_dir="/content/Wake2vec_adapter",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=6e-4,
    num_train_epochs=3.0,
    bf16=use_bf16,
    fp16=use_fp16,
    logging_steps=25,
    eval_strategy="steps",  # replaces the deprecated `evaluation_strategy` (transformers >= 4.41)
    eval_steps=100,
    save_steps=100,  # must be a multiple of eval_steps for load_best_model_at_end
    save_total_limit=2,
    # NOTE(review): with a PEFT-wrapped model, checkpoints/best-model reload cover the adapter
    # weights; confirm whether the trainable embedding rows are saved and restored too.
    load_best_model_at_end=True,
    report_to=[]
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    data_collator=collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)
trainer.train()
trainer.save_model("/content/Wake2vec_adapter/final")
tok.save_pretrained("/content/Wake2vec_adapter/final_tok")
print("Saved adapter + tokenizer to /content/Wake2vec_adapter")

In [None]:
import unicodedata, re
from sklearn.metrics.pairwise import cosine_similarity
# Word-shape filter: one ASCII/Latin-1 letter, then only such letters, straight or curly
# apostrophes, or hyphens (the À-Ö / Ø-ö / ø-ÿ ranges skip × and ÷).
LETTER = re.compile(r"^[A-Za-zÀ-ÖØ-öø-ÿ][A-Za-zÀ-ÖØ-öø-ÿ'’-]*$")

def is_latin(ch):
    """Return True if `ch` is a Latin-script character or one of the marks ' ’ -.

    Characters without a Unicode name (unicodedata.name raises ValueError) count as non-Latin.
    """
    try:
        char_name = unicodedata.name(ch)
    except ValueError:
        return False
    return "LATIN" in char_name or ch in "'’-"

def is_wordlike(tok):
    """True when `tok` is a plain Latin word piece.

    That means: no "▁" start-of-word marker, a match for LETTER, and only characters that
    is_latin accepts. (The parameter shadows the global tokenizer name inside this function.)
    """
    return "▁" not in tok and bool(LETTER.match(tok)) and all(is_latin(c) for c in tok)

vocab = tok.get_vocab(); inv = {i:t for t,i in vocab.items()}
# Set membership instead of scanning the `new_terms` list once per vocabulary entry.
new_term_set = set(new_terms)
# NOTE(review): the `"▁" + t` test admits a token whose ▁-prefixed form is a lexicon term. If the
# intent was to admit the added "▁term" variants, that would be `t[1:] in new_term_set` — confirm.
allowed_ids = [i for i, t in inv.items()
               if is_wordlike(t) or t in new_term_set or ("▁" + t) in new_term_set]

# numpy has no bfloat16, so upcast to float32 before .numpy() (the model may be loaded in bf16).
W = model.get_input_embeddings().weight.detach().cpu().float().numpy()
# Row-normalised embeddings of the allowed pool, so a dot product is a cosine similarity.
Wa = W[allowed_ids] / (np.linalg.norm(W[allowed_ids], axis=1, keepdims=True) + 1e-12)

def clean_neighbors(terms, k=8, max_terms=30):
    """Nearest neighbours by cosine similarity, restricted to the filtered `allowed_ids` pool.

    Args:
        terms: token strings to query; strings the tokenizer does not know are skipped.
        k: number of neighbours returned per term.
        max_terms: only the first `max_terms` entries of `terms` are queried.

    Returns:
        dict term -> list of up to `k` neighbour tokens, most similar first. The query token
        itself is excluded; as a member of the pool it would otherwise always rank first.
    """
    out = {}
    for t in terms[:max_terms]:
        tid = tok.convert_tokens_to_ids(t)
        if tid is None or tid == tok.unk_token_id:
            continue
        v = W[tid] / (np.linalg.norm(W[tid]) + 1e-12)
        sims = Wa @ v
        # Take k+1 candidates so that dropping the query itself still leaves k neighbours.
        m = min(k + 1, len(sims))
        if m == 0:
            out[t] = []
            continue
        cand = np.argpartition(-sims, m - 1)[:m]
        cand = cand[np.argsort(-sims[cand])]
        out[t] = [inv[allowed_ids[j]] for j in cand if allowed_ids[j] != tid][:k]
    return out

# Show neighbour lists for the first dozen lexicon terms.
probe = new_terms[:12]
nn = clean_neighbors(probe, k=8)
for term, neighbours in nn.items():
    print(f"{term:>20} -> {neighbours}")

In [None]:
def complete(prompt, max_new_tokens=60, temperature=1.15, top_p=0.95,
             repetition_penalty=1.1, no_repeat_ngram_size=3):
    """Sample a continuation of `prompt` from the current model.

    Args:
        prompt: text to continue.
        max_new_tokens: maximum number of generated tokens.
        temperature, top_p, repetition_penalty, no_repeat_ngram_size: sampling controls passed to
            `model.generate` (defaults match the original hard-coded values).

    Returns:
        The decoded sequence (prompt included) with special tokens removed.
    """
    # Training may leave the model in train mode, which keeps LoRA dropout active while sampling.
    model.eval()
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    # Explicit pad id avoids generate() warning/guessing when the tokenizer lacks a pad token.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            pad_token_id=pad_id,
        )
    return tok.decode(out[0], skip_special_tokens=True)

# A few open-ended prompts to eyeball how the adapted model writes.
tests = [
    "By the river I thought of",
    "At night I dream of",
    "In the story of this book,",
    "Explain gradient descent in the style of Joyce:"
]
for prompt in tests:
    continuation = complete(prompt)
    print("\n===", prompt, "\n", continuation)

In [None]:
# Export minimal adapter pack: changed embedding rows (bare + ▁ forms)
save_dir = Path("/content/Wake2vec_adapter/minipack")
save_dir.mkdir(parents=True, exist_ok=True)
export_terms = new_terms + ["▁"+t for t in new_terms if not t.startswith("▁")]
changed = [tok.convert_tokens_to_ids(t) for t in export_terms]
# Drop strings the tokenizer does not know (mapped to the unk id, or None without an unk token).
changed = [i for i in changed if i is not None and i != tok.unk_token_id]
# numpy has no bfloat16, so upcast to float32 before .numpy() (the model may be loaded in bf16).
emb_slice = model.get_input_embeddings().weight[changed].detach().cpu().float().numpy()
np.save(save_dir / "new_token_ids.npy", np.array(changed, dtype=np.int32))
np.save(save_dir / "new_token_vectors.npy", emb_slice)
(save_dir / "README.txt").write_text(
    "Rows from input embedding for Wake2vec tokens (bare + ▁forms).\n"
    "Apply to base model by index assignment.\n", encoding="utf-8")
print("Mini pack saved:", save_dir)