<a href="https://colab.research.google.com/github/mahb97/Wake2vec/blob/main/Wake2Vec_three_cell.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Setup and data loading

In [None]:
!pip install transformers accelerate datasets adafactor bitsandbytes scikit-learn umap-learn matplotlib
from pathlib import Path
import json, math, random, torch
from transformers import AutoTokenizer, AutoModelForCausalLM

import csv

ROOT = Path(".")
FW = (ROOT/"data"/"FW_TEXT.txt").read_text(encoding="utf-8")


def _read_morphemes(path):
    """Parse the morpheme CSV into ``[prefix, suffix, example]`` rows.

    The first row is treated as a header and skipped. Each field is stripped
    of surrounding whitespace and fully blank rows are ignored. Quoted fields
    are handled by the ``csv`` module instead of a naive ``split(",")``.

    Raises:
        ValueError: if a non-blank row does not have exactly three fields
            (reported with its line number, instead of a later unpacking error).
    """
    rows = []
    with open(path, encoding="utf-8", newline="") as fh:
        reader = csv.reader(fh)
        next(reader, None)  # header row
        for row in reader:
            if not any(field.strip() for field in row):
                continue
            if len(row) != 3:
                raise ValueError(
                    f"{path}:{reader.line_num}: expected 3 fields "
                    f"(prefix,suffix,example), got {len(row)}: {row!r}"
                )
            rows.append([field.strip() for field in row])
    return rows


MORPHEMES = _read_morphemes(ROOT/"data"/"morphemes.csv")
# Group example words by the affix they illustrate; an empty prefix/suffix
# column means the row contributes only to the other mapping.
prefixes = {}
suffixes = {}
for pfx, sfx, ex in MORPHEMES:
    if pfx:
        prefixes.setdefault(pfx, []).append(ex)
    if sfx:
        suffixes.setdefault(sfx, []).append(ex)

def synthetic_words(n=1200, roots=("river thunder word sound dance queen storm tree night sun rain book").split(),
                    prefix_pool=None, suffix_pool=None):
    """Generate up to ``n`` distinct prefix+root+suffix pseudo-words.

    Makes at most ``2 * n`` random draws of (prefix, suffix, root) and keeps
    distinct combinations in first-seen order, stopping as soon as ``n`` have
    been collected. Fewer than ``n`` words come back when the combination
    space is small or draws collide.

    Args:
        n: maximum number of words to return.
        roots: stems placed between the prefix and the suffix.
        prefix_pool: mapping whose keys are the prefixes to draw from;
            defaults to the module-level ``prefixes``.
        suffix_pool: mapping whose keys are the suffixes to draw from;
            defaults to the module-level ``suffixes``.

    Returns:
        List of distinct words in generation order. Unlike slicing a set of
        strings (whose order depends on hash randomization), this is
        reproducible under a fixed ``random`` seed.

    Raises:
        ValueError: if ``n > 0`` and there is no prefix, suffix or root to draw.
    """
    if n <= 0:
        return []
    prefix_pool = prefixes if prefix_pool is None else prefix_pool
    suffix_pool = suffixes if suffix_pool is None else suffix_pool
    # Loop-invariant choice lists, built once instead of on every draw.
    pfx_choices = list(prefix_pool)
    sfx_choices = list(suffix_pool)
    root_choices = list(roots)
    if not (pfx_choices and sfx_choices and root_choices):
        raise ValueError("synthetic_words needs at least one prefix, suffix and root")
    seen = {}  # insertion-ordered "set" -> deterministic output order
    for _ in range(n * 2):
        p = random.choice(pfx_choices)
        s = random.choice(sfx_choices)
        r = random.choice(root_choices)
        seen.setdefault(f"{p}{r}{s}", None)
        if len(seen) >= n:
            break
    return list(seen)

# Seed Python's and torch's RNGs so the synthetic-word draws, the later
# random.sample of training words and the random embedding initialisations
# repeat across runs. Change SEED to explore a different sample.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

syn = synthetic_words()
base = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
tok = AutoTokenizer.from_pretrained(base, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.float32, device_map="auto")

Expand the tokenizer and compose embeddings for the new tokens

In [None]:
def avg_vec(terms, emb, tok):
    """Return the mean embedding row of the single-token ``terms``.

    Each term is encoded without special tokens. Only terms that map to
    exactly one id contribute their row of ``emb.weight``. Returns ``None``
    when no term qualifies.
    """
    encoded = (tok.encode(term, add_special_tokens=False) for term in terms)
    rows = [emb.weight.data[token_ids[0]] for token_ids in encoded if len(token_ids) == 1]
    if not rows:
        return None
    return torch.stack(rows).mean(0)

# Keep only the synthetic words the base tokenizer splits into several pieces;
# a word that already encodes to a single id needs no new vocabulary entry.
new_tokens = [w for w in syn if len(tok.encode(w, add_special_tokens=False)) > 1]
added = tok.add_tokens(new_tokens, special_tokens=False)
print(f"Added {added} of {len(new_tokens)} candidate tokens; vocab size is now {len(tok)}")
# mean_resizing=False: the new rows are filled in explicitly further down
# rather than initialised to the mean of the existing embeddings.
model.resize_token_embeddings(len(tok), mean_resizing=False)

emb = model.get_input_embeddings()
with torch.no_grad():
    alpha = 0.25  # weight of the prefix and of the suffix; the root gets 1 - 2*alpha
    W = emb.weight.data
    unk_id = tok.unk_token_id
    # Resolve row ids for the new words. HF tokenizers map strings they do not
    # know to the unk id, so skip those instead of overwriting e.g. <unk>.
    targets = []
    for w in new_tokens:
        tid = tok.convert_tokens_to_ids(w)
        if tid is not None and tid != unk_id:
            targets.append((w, tid))
    new_ids = {tid for _, tid in targets}
    # Estimate the embedding scale from pre-existing rows only; the rows just
    # created for the new tokens hold placeholder values that would skew it.
    old_rows = torch.ones(W.shape[0], dtype=torch.bool, device=W.device)
    old_rows[torch.tensor(sorted(new_ids), dtype=torch.long, device=W.device)] = False
    std = W[old_rows].std().item()
    for w, tid in targets:
        # crude split guess: first listed prefix and first listed suffix that match
        p = next((p for p in prefixes if w.startswith(p)), None)
        s = next((s for s in suffixes if w.endswith(s)), None)
        root = w[len(p):len(w)-len(s)] if (p and s and len(w) > len(p)+len(s)) else w
        vp = avg_vec(prefixes.get(p, []), emb, tok)
        vs = avg_vec(suffixes.get(s, []), emb, tok)
        vr_ids = tok.encode(root, add_special_tokens=False)
        # Use the root's embedding only when it is a single *pre-existing*
        # token: after add_tokens() a whole new word (the root == w fallback)
        # encodes to its own placeholder id, which is not a meaningful vector.
        if len(vr_ids) == 1 and vr_ids[0] not in new_ids:
            vr = W[vr_ids[0]]
        else:
            # Random fallback, allocated on the embedding's device and dtype so
            # it can be combined with the (possibly GPU-resident) affix vectors.
            vr = torch.randn(emb.embedding_dim, device=W.device, dtype=W.dtype) * std * 0.5
        # Missing prefix/suffix averages fall back to the root vector.
        comp = (alpha * (vp if vp is not None else vr)
                + (1 - 2*alpha) * vr
                + alpha * (vs if vs is not None else vr))
        comp = comp + torch.randn_like(comp) * std * 0.01  # small jitter
        W[tid] = comp

# tie head
with torch.no_grad():
    model.lm_head.weight = emb.weight

Two-phase training (embeddings only, then full fine-tuning) and a quick evaluation

In [None]:
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import Dataset

def make_blocks(text, max_len=2048, stride=1024, tokenizer=None):
    """Split ``text`` into overlapping fixed-length token windows.

    Windows start every ``stride`` tokens. Every window that fits entirely
    in the text is emitted, including the one ending exactly at the last
    token. Non-empty text shorter than ``max_len`` yields a single shorter
    block, so small inputs do not produce an empty dataset.

    Args:
        text: raw text, tokenized without special tokens.
        max_len: tokens per window.
        stride: step between window starts (``stride < max_len`` overlaps).
        tokenizer: object with ``encode(text, add_special_tokens=False)``;
            defaults to the module-level ``tok``.

    Returns:
        List of ``{"input_ids": [...]}`` dicts (empty for empty text).
    """
    if tokenizer is None:
        tokenizer = tok
    ids = tokenizer.encode(text, add_special_tokens=False)
    if len(ids) < max_len:
        return [{"input_ids": ids}] if ids else []
    # +1 so the final window ending exactly at len(ids) is included.
    return [{"input_ids": ids[i:i + max_len]}
            for i in range(0, len(ids) - max_len + 1, stride)]

# Training text: up to 400 synthetic words (one per line) followed by the first
# 600k characters of the corpus; the next 30k characters are held out.
n_syn = min(400, len(syn))  # random.sample raises ValueError if k > len(syn)
train_text = "\n".join(random.sample(syn, k=n_syn)) + "\n" + FW[:600000]
valid_text = FW[600000:630000]

train_ds = Dataset.from_list(make_blocks(train_text))
valid_ds = Dataset.from_list(make_blocks(valid_text))
if len(train_ds) == 0 or len(valid_ds) == 0:
    raise ValueError(
        f"Empty split: {len(train_ds)} train / {len(valid_ds)} valid blocks; "
        "is FW_TEXT.txt long enough?"
    )

# The collator pads each batch through the tokenizer, which refuses to pad
# when no pad token is defined; reuse EOS as the pad token in that case.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
dc = DataCollatorForLanguageModeling(tok, mlm=False)

def freeze_all_but_embeddings(m):
    """Freeze every parameter of ``m`` except its input and output embeddings.

    Uses the ``get_input_embeddings()`` / ``get_output_embeddings()``
    accessors rather than a hard-coded ``lm_head`` attribute. When the output
    accessor returns ``None``, only the input embeddings stay trainable. Tied
    embeddings share one parameter, which is simply enabled twice.
    """
    for p in m.parameters():
        p.requires_grad = False
    for module in (m.get_input_embeddings(), m.get_output_embeddings()):
        if module is None:
            continue
        for p in module.parameters():
            p.requires_grad = True

# Phase 1: train only the input/output embeddings; all other weights are frozen.
freeze_all_but_embeddings(model)
model.config.use_cache = False  # no KV cache needed while training
args1 = TrainingArguments(
    output_dir="runs/phase1",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=5e-4,
    num_train_epochs=1,
    # `eval_strategy` replaces the deprecated (and since removed)
    # `evaluation_strategy` argument.
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=200,
    eval_steps=200,
    logging_steps=50,
    load_best_model_at_end=False,
    report_to="none",
)
Trainer(model=model, args=args1, train_dataset=train_ds, eval_dataset=valid_ds, data_collator=dc).train()

# Phase 2: unfreeze everything and fine-tune the whole model at a lower
# learning rate; the best checkpoint (by eval loss) is restored at the end.
for p in model.parameters():
    p.requires_grad = True
args2 = TrainingArguments(
    output_dir="runs/phase2",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=2e-5,
    num_train_epochs=2,
    warmup_ratio=0.10,
    weight_decay=0.01,
    # `eval_strategy` replaces the deprecated (and since removed)
    # `evaluation_strategy` argument; it must match save_strategy for
    # load_best_model_at_end.
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=200,
    eval_steps=200,
    logging_steps=50,
    gradient_checkpointing=True,
    fp16=False,
    load_best_model_at_end=True,
    report_to="none",
)
Trainer(model=model, args=args2, train_dataset=train_ds, eval_dataset=valid_ds, data_collator=dc).train()

# Quick neighbour diagnostic: for each new token, the 5 vocabulary rows whose
# input embeddings are most cosine-similar to it (the token itself excluded).
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

with torch.no_grad():
    # .float() so numpy can hold the values even if the model is in half precision
    W = model.get_input_embeddings().weight.detach().float().cpu().numpy()
ids_new = [tok.convert_tokens_to_ids(t) for t in new_tokens]
if ids_new:
    sim = cosine_similarity(W[ids_new], W)
    # Exclude each token's own column explicitly instead of assuming it is
    # always ranked first (ties or rounding can displace it).
    sim[np.arange(len(ids_new)), ids_new] = -np.inf
    top5 = np.argsort(-sim, axis=1)[:, :5]
else:
    top5 = np.empty((0, 5), dtype=np.int64)
stats = {
    "new_token_count": len(new_tokens),
    "example_top5_ids": top5[:5].tolist(),
    # Human-readable form of the same neighbours for the first examples.
    "example_top5_tokens": {
        new_tokens[i]: tok.convert_ids_to_tokens(top5[i].tolist())
        for i in range(min(5, len(top5)))
    },
}
metrics_dir = Path("runs/phase2/metrics")
metrics_dir.mkdir(parents=True, exist_ok=True)
(metrics_dir/"quick_neighbors.json").write_text(
    json.dumps(stats, indent=2, ensure_ascii=False), encoding="utf-8"
)
print("Saved runs/phase2/metrics/quick_neighbors.json")