<a href="https://colab.research.google.com/github/mahb97/Wake2vec/blob/main/token_injection_and_training_clean.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wake2vec Token Injection & Training


In [1]:
from google.colab import drive; drive.mount('/content/drive')
from datetime import datetime
from pathlib import Path

RUN_ID = datetime.now().strftime("wake2vec_%Y%m%d_%H%M")
BASE = Path(f"/content/drive/MyDrive/Wake2vec_runs/{RUN_ID}")
ADAPTER_DIR = BASE/'adapter'; RESULTS = BASE/'results'
for p in (ADAPTER_DIR, RESULTS): p.mkdir(parents=True, exist_ok=True)
print("Saving to:", BASE)

Mounted at /content/drive
Saving to: /content/drive/MyDrive/Wake2vec_runs/wake2vec_20251028_2148


model + tok

In [2]:
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # CPU fallback: "distilgpt2"
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Add tokens, resize, tie head

In [8]:
# Load lex
from pathlib import Path, PurePath
LEX_PATH = Path("/content/drive/MyDrive/Wake2vec_runs/wake_lexicon.txt")  # adjust if needed
lex = [t.strip() for t in LEX_PATH.read_text(encoding="utf-8").splitlines() if t.strip()]

new_terms = []
for t in lex:
    if not t: continue
    new_terms.append(t)
    if not t.startswith("▁"):
        new_terms.append("▁"+t)

# add tokens
added = tok.add_tokens(list(dict.fromkeys(new_terms)))
print("Added tokens:", added)
model.resize_token_embeddings(len(tok))

# tie lm_head to input embeddings
if model.get_output_embeddings() is not None:
    model.get_output_embeddings().weight = model.get_input_embeddings().weight
    # after tok.add_tokens(...)
try:
    model.resize_token_embeddings(len(tok), mean_resizing=False)
except TypeError:
    # older Transformers without the flag
    model.resize_token_embeddings(len(tok))

# immediately tie head to inputs
with torch.no_grad():
    if model.get_output_embeddings() is not None:
        model.get_output_embeddings().weight = model.get_input_embeddings().weight


Added tokens: 0


Make embeddings trainable

In [9]:
# freeze everything
for p in model.parameters():
    p.requires_grad = False

# unfreeze input embeddings (KEY)
emb = model.get_input_embeddings()
emb.weight.requires_grad_(True)

# Opt LoRA tiny
USE_LORA = True
if USE_LORA:
    from peft import LoraConfig, get_peft_model, TaskType
    lcfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=16, lora_dropout=0.05,
        target_modules=["q_proj","k_proj","v_proj","o_proj"]
    )
    model = get_peft_model(model, lcfg)
    model.print_trainable_parameters()
def enable_embeddings_and_lora_only(model):
    # freeze everything
    for _, p in model.named_parameters():
        p.requires_grad = False

    # turn on LoRA params
    for name, p in model.named_parameters():
        if "lora_" in name:
            p.requires_grad = True

    # turn on input embeddings
    emb = model.get_input_embeddings()
    emb.weight.requires_grad_(True)

    # report
    total = sum(p.numel() for p in model.parameters())
    train = sum(p.numel() for p in model.parameters() if p.requires_grad)
    emb_train = emb.weight.numel() if emb.weight.requires_grad else 0
    print(f"Total params: {total:,} | Trainable: {train:,}")
    print(f"Embeddings trainable? {emb.weight.requires_grad} | embed rows: {emb_train:,}")

# AFTER LoRA wrap & AFTER any tokenizer resize/tie
enable_embeddings_and_lora_only(model)

# trainability report
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params: {total:,} | Trainable: {trainable:,}")
print("Embeddings trainable?", emb.weight.requires_grad)



trainable params: 2,252,800 || all params: 1,219,211,264 || trainable%: 0.1848
Total params: 1,219,211,264 | Trainable: 250,234,880
Embeddings trainable? True | embed rows: 247,982,080
Total params: 1,219,211,264 | Trainable: 250,234,880
Embeddings trainable? True


In [10]:
for n,p in model.named_parameters():
    if p.requires_grad and ("embed" in n or "lora_" in n):
        print(" •", n, p.shape)

 • base_model.model.base_model.model.model.embed_tokens.weight torch.Size([121085, 2048])
 • base_model.model.base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight torch.Size([8, 2048])
 • base_model.model.base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight torch.Size([2048, 8])
 • base_model.model.base_model.model.model.layers.0.self_attn.k_proj.lora_A.default.weight torch.Size([8, 2048])
 • base_model.model.base_model.model.model.layers.0.self_attn.k_proj.lora_B.default.weight torch.Size([256, 8])
 • base_model.model.base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight torch.Size([8, 2048])
 • base_model.model.base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight torch.Size([256, 8])
 • base_model.model.base_model.model.model.layers.0.self_attn.o_proj.lora_A.default.weight torch.Size([8, 2048])
 • base_model.model.base_model.model.model.layers.0.self_attn.o_proj.lora_B.default.weight torch.Size([2048, 8])
 • base_

In [11]:
baseline_dir = RESULTS / "baseline_now"
baseline_dir.mkdir(parents=True, exist_ok=True)

anchors = ["river","history","book","dream","night","language","irish","dublin","cork"]
probe = [t for t in new_terms[:150] if tok.convert_tokens_to_ids(t)!=tok.unk_token_id] + anchors
ids = [tok.convert_tokens_to_ids(t) for t in probe]
W = model.get_input_embeddings().weight.detach().cpu().numpy()
import numpy as np, json
P0 = W[ids] @ W[ids].T
np.save(baseline_dir/"probe_ids.npy", np.array(ids, dtype=np.int32))
np.save(baseline_dir/"P0.npy", P0)
(baseline_dir/"probe_tokens.json").write_text(json.dumps(probe, ensure_ascii=False, indent=2))
print("Baseline saved to", baseline_dir)

Baseline saved to /content/drive/MyDrive/Wake2vec_runs/wake2vec_20251028_2148/results/baseline_now


Wake-dense dataset

In [12]:
from datasets import Dataset, DatasetDict
import re, unicodedata

RAW_WAKE = Path("/content/drive/MyDrive/Wake2vec_runs/fw.txt").read_text(encoding="utf-8")
txt = unicodedata.normalize("NFC", RAW_WAKE)
# anchor at 'riverrun'
i = txt.lower().find("riverrun")
if i > 0: txt = txt[i:]

paras = [re.sub(r"\s+"," ",p.strip()) for p in re.split(r"\n\s*\n", txt) if len(p.split())>6]
RADIUS = 2
windows = []
for idx in range(len(paras)):
    lo, hi = max(0, idx-RADIUS), min(len(paras), idx+RADIUS+1)
    windows.append(" ".join(paras[lo:hi]))

def tok_map(batch):
    return tok(batch["text"], truncation=True, max_length=512, padding=False)

split = int(0.95*len(windows))
ds = DatasetDict({
    "train": Dataset.from_dict({"text": windows[:split]}).map(tok_map, batched=True, remove_columns=["text"]),
    "validation": Dataset.from_dict({"text": windows[split:]}).map(tok_map, batched=True, remove_columns=["text"]),
})
from transformers import default_data_collator as collator
print(ds)

Map:   0%|          | 0/1456 [00:00<?, ? examples/s]

Map:   0%|          | 0/77 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 1456
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 77
    })
})


snapshot before training

In [13]:
import numpy as np, json, time
BAS_DIR = RESULTS / f"baseline_{int(time.time())}"
BAS_DIR.mkdir(parents=True, exist_ok=True)

anchors = ["river","history","book","dream","night","language","irish","dublin","cork"]
probe_new = [t for t in lex if tok.convert_tokens_to_ids(t)!=tok.unk_token_id][:150]
probe = probe_new + anchors
ids = [tok.convert_tokens_to_ids(t) for t in probe]

W = model.get_input_embeddings().weight.detach().cpu().numpy()
P0 = W[ids] @ W[ids].T
np.save(BAS_DIR/"probe_ids.npy", np.array(ids, dtype=np.int32))
np.save(BAS_DIR/"P0.npy", P0)
(BAS_DIR/"probe_tokens.json").write_text(json.dumps(probe, ensure_ascii=False, indent=2))
print("Baseline saved to", BAS_DIR)

Baseline saved to /content/drive/MyDrive/Wake2vec_runs/wake2vec_20251028_2148/results/baseline_1761689564


Train

In [15]:
emb = model.get_input_embeddings().weight
print("Embeddings trainable?", emb.requires_grad)

# how many of the *new* tokens exist + will learn?
new_ids = [tok.convert_tokens_to_ids(t) for t in new_terms]
new_ids = sorted({i for i in new_ids if i != tok.unk_token_id})
print("New token rows:", len(new_ids))
print("Added trainable params from embeddings ~", len(new_ids)*emb.shape[1])

Embeddings trainable? True
New token rows: 89980
Added trainable params from embeddings ~ 184279040


In [20]:
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir=str(ADAPTER_DIR),
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=6e-4,
    num_train_epochs=3,
    fp16=True,
    logging_steps=25,

    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,

    lr_scheduler_type="cosine",
    warmup_ratio=0.10,
    max_grad_norm=1.0,
    eval_accumulation_steps=16,
    gradient_checkpointing=True,
    report_to=[]
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    data_collator=collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)
trainer.train()
# save final to Drive
trainer.save_model(str(ADAPTER_DIR/"final"))
tok.save_pretrained(str(ADAPTER_DIR/"final_tok"))
print("Saved to", ADAPTER_DIR)

The model is already on multiple devices. Skipping the move to device specified in `args`.


ValueError: expected sequence of length 512 at dim 1 (got 510)

Metrics + Hero

In [None]:
# retie for generation
if model.get_output_embeddings() is not None:
    model.get_output_embeddings().weight = model.get_input_embeddings().weight
model.config.use_cache = True

import numpy as np, pandas as pd, json, glob
import umap, matplotlib.pyplot as plt, textwrap, os

# load most recent baseline from RESULTS
bdirs = sorted([p for p in RESULTS.glob("baseline_*")], key=lambda p: p.stat().st_mtime)
ids = np.load(bdirs[-1]/"probe_ids.npy"); P0 = np.load(bdirs[-1]/"P0.npy")
W = model.get_input_embeddings().weight.detach().cpu().numpy()
P1 = W[ids] @ W[ids].T

pip = float(np.linalg.norm(P0 - P1, 'fro') / (np.linalg.norm(P0, 'fro') + 1e-12))
X = W[ids] - W[ids].mean(0, keepdims=True)
C = (X.T @ X) / max(1, len(ids)-1)
evals = np.clip(np.linalg.eigvalsh(C), 1e-12, None)
isotropy = float(np.exp(np.log(evals).mean()) / evals.mean())
topk=10
N0 = np.argsort(-P0,axis=1)[:,:topk]; N1 = np.argsort(-P1,axis=1)[:,:topk]
overlap = float(np.mean([len(set(N0[i]).intersection(set(N1[i]))) / topk for i in range(len(ids))]))
pd.DataFrame([{"pip_loss":pip,"isotropy":isotropy,"top10_overlap":overlap,"num_probe_tokens":int(len(ids))}]).to_csv(RESULTS/"metrics_summary.csv", index=False)
print("Wrote", RESULTS/"metrics_summary.csv")

# Hero
probe_show = list(dict.fromkeys(lex[:40] + ["river","history","book","dream","night","language","irish","dublin"]))
probe_ids = [tok.convert_tokens_to_ids(t) for t in probe_show if tok.convert_tokens_to_ids(t)!=tok.unk_token_id]
vecs = model.get_input_embeddings().weight[probe_ids].detach().cpu().numpy()
xy = umap.UMAP(n_neighbors=8, min_dist=0.1, metric="cosine", random_state=42).fit_transform(vecs)

def complete_portfolio(prompt, max_new_tokens=120):
    ins = tok(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(**ins, max_new_tokens=max_new_tokens, do_sample=True,
                             temperature=0.90, top_p=0.88, repetition_penalty=1.2,
                             no_repeat_ngram_size=4, cache_implementation="static")
    return tok.decode(out[0], skip_special_tokens=True)

samples = {
  "By the river I thought of — Earwicker, the prankquean, and": complete_portfolio("By the river I thought of — Earwicker, the prankquean, and", 120),
  "In the story of this book, HCE remembers that": complete_portfolio("In the story of this book, HCE remembers that", 120),
  "Explain gradient descent in the style of Joyce: riverrun,": complete_portfolio("Explain gradient descent in the style of Joyce: riverrun,", 120),
}

plt.figure(figsize=(10,6))
plt.subplot(1,2,1); plt.scatter(xy[:,0], xy[:,1])
inv = {i:t for t,i in tok.get_vocab().items()}
for pid,(x,y) in zip(probe_ids, xy): plt.text(x,y,inv.get(int(pid),"."),fontsize=8)
plt.title("Wake2vec: neighborhood map (UMAP)")
plt.subplot(1,2,2); y0=1.0
for k,v in samples.items():
    plt.text(0.0,y0,k+"\n"+textwrap.fill(v.replace("\n"," "), width=44), fontsize=8, va='top'); y0-=0.32
plt.axis('off'); plt.title("Sample completions")
plt.tight_layout(); plt.savefig(RESULTS/"hero.png", dpi=200)
print("Saved", RESULTS/"hero.png")

tbc