<a href="https://colab.research.google.com/github/mahb97/Wake2vec/blob/main/Wake2Vec_morpheme_expansion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wake2Vec Morpheme Expansion Pipeline

This notebook documents a controlled procedure for integrating Joyce-style neologisms into a compact GPT-type language model through morphology-aware token expansion. I curate a small lexicon of prefixes and suffixes and generate synthetic candidates, then extend the tokenizer to admit previously split neologisms as single tokens. New embeddings are initialised by morphemic composition, using the rule \(E(\text{word}) = \alpha\,E(\text{prefix}) + (1 - 2\alpha)\,E(\text{root}) + \alpha\,E(\text{suffix}) + \varepsilon\), where \(\alpha\) is a fixed weight and \(\varepsilon\) is small Gaussian noise that prevents identical vectors. Training proceeds in two stages: an embedding-only warm-up on a mixture of synthetic lines and *Finnegans Wake* text, followed by a short full-model fine-tune under conservative schedules suitable for a T4 environment.
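As a sanity check on the composition rule: the weights \(\alpha\), \(1-2\alpha\), \(\alpha\) always sum to 1, so with \(\alpha = 0.25\) the root contributes half of the composed vector and each affix a quarter. The sketch below applies the rule to toy 3-dimensional vectors (made-up values, not real TinyLlama embeddings) and, like the notebook's composition loop, substitutes the root vector when an affix vector is unavailable. The function name `compose` is illustrative.

```python
import random

def compose(prefix_vec, root_vec, suffix_vec, alpha=0.25, noise_std=0.0, rng=None):
    """E(word) = alpha*E(prefix) + (1 - 2*alpha)*E(root) + alpha*E(suffix) + eps.

    A missing affix (None) is replaced by the root vector, mirroring the
    notebook's fallback. eps is Gaussian noise with standard deviation noise_std.
    """
    rng = rng or random.Random(0)
    p = prefix_vec if prefix_vec is not None else root_vec
    s = suffix_vec if suffix_vec is not None else root_vec
    return [
        alpha * pi + (1 - 2 * alpha) * ri + alpha * si + rng.gauss(0.0, noise_std)
        for pi, ri, si in zip(p, root_vec, s)
    ]

prefix = [1.0, 0.0, 0.0]   # toy vectors, not real embeddings
root   = [0.0, 1.0, 0.0]
suffix = [0.0, 0.0, 1.0]

print(compose(prefix, root, suffix))   # [0.25, 0.5, 0.25]
print(compose(None, root, suffix))     # [0.0, 0.75, 0.25]  (root fills the prefix slot)
```

In the notebook, \(\varepsilon\) uses a standard deviation of 1% of the embedding matrix's overall standard deviation, which is small enough to leave the composed direction intact while keeping otherwise identical compositions distinct.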

I report top-five neighbor overlap for the newly introduced tokens before and after training, track shifts in embedding norms, provide a t-SNE projection of the new tokens against pre-training neighbor centroids, and save JSON snapshots of neighborhoods at each stage. These diagnostics are intended to show coherent integration of the new forms into the embedding space rather than collapse or runaway drift, and to make the procedure straightforward to reproduce on modest hardware.

**Config**

Base model: `TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T`. Composition weight \(\alpha = 0.25\). Maximum sequence length set to 1024 to respect T4 memory limits. Batching uses `per_device_train_batch_size = 1` with `gradient_accumulation_steps = 8`, attention implementation set to `eager`, and `use_cache = False`. Phase 1 trains input embeddings and the tied output head only; Phase 2 unfreezes all parameters with a warm-up ratio of 0.10 and light weight decay. All runs write plots and machine-readable artifacts to `runs/<RUN_ID>/` and generate a brief HTML report.

---

## Run controls
- **BASE_MODEL:** `TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T`
- **α (composition weight):** `0.25` (tunable)
- **Max seq length:** `1024` (T4-safe; raise only if VRAM allows)
- **Batching:** `per_device_train_batch_size=1`, `gradient_accumulation_steps=8`
- **Attn impl:** `eager` (avoid SDPA spikes on T4)
- **Two phases:**
  - **Phase 1:** embeddings + lm_head only, 8-bit AdamW (`adamw_bnb_8bit`; Adafactor is an alternative), 1 epoch
  - **Phase 2:** full model, short run, warmup 0.10

## Inputs
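- `data/FW_TEXT.txt`: *Finnegans Wake* plain text (a slice is used for the demo)
- `data/morpheme_data.json` or `data/morphemes.csv` (JSON is used if both exist)  
  CSV rows follow `type,morpheme,example1,example2,…` (`type` is `prefix` or `suffix`; the first row is skipped as a header).  
  Either file provides two maps:
  - `prefixes`: `{ prefix → [example words…] }`
  - `suffixes`: `{ suffix → [example words…] }`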
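For reference, a minimal `morpheme_data.json` in this shape might look like the snippet below. The morphemes and example words are illustrative placeholders, not the notebook's actual 15-prefix / 15-suffix lexicon, and `validate_morpheme_data` is a hypothetical helper for checking the structure before the pipeline runs.

```python
import json

# Illustrative content only; the real lexicon lives in data/morpheme_data.json.
EXAMPLE_JSON = """
{
  "prefixes": {"pre": ["preview", "prepare"], "re": ["return", "replay"]},
  "suffixes": {"er": ["singer", "dancer"], "ness": ["darkness", "kindness"]}
}
"""

def validate_morpheme_data(data):
    """Hypothetical check: both maps exist and map each morpheme to a non-empty list of words."""
    problems = []
    for key in ("prefixes", "suffixes"):
        section = data.get(key)
        if not isinstance(section, dict):
            problems.append(f"missing or non-dict '{key}'")
            continue
        for morpheme, examples in section.items():
            if not isinstance(examples, list) or not examples:
                problems.append(f"{key}[{morpheme!r}] has no example words")
            elif not all(isinstance(w, str) and w for w in examples):
                problems.append(f"{key}[{morpheme!r}] contains a non-string or empty entry")
    return problems

data = json.loads(EXAMPLE_JSON)
print(validate_morpheme_data(data))   # []
```

The example lists matter downstream: the generator weights each affix by how many examples it has, and the composition step averages only those example words that tokenize to a single ID, so an affix whose examples are all multi-token falls back to the root vector.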

## Outputs (per run)
- `runs/<RUN_ID>/metrics/`
  - `pre_morpheme_snapshot.json`
  - `morpheme_comparison_p1.json` *(midpoint, after Phase 1)*
  - `morpheme_comparison.json` *(final, after Phase 2)*
  - `summary_stats_p1.json`, `summary_stats.json`
- `runs/<RUN_ID>/plots/`
  - `hist_overlap_top5(_p1).png`, `hist_norm_change(_p1).png`
  - `scatter_norm_vs_overlap.png`, `tsne_newtokens_vs_precentroids.png`
- `reports/Wake2Vec_Report.html`

## Quickstart
1. **Reset & install** deps (Colab-friendly).  
2. **Load data** (prefers JSON).  
3. **Generate** synthetic forms (prefix + root + suffix).  
4. **Expand tokenizer** (add new tokens); compose embeddings with α-rule; tie head.  
5. **Phase 1**: train embeddings only. Saves midpoint snapshot.
6. **Phase 2**: unfreeze and short fine-tune.  
7. **Diagnostics**: compute overlap@5, norm deltas, t-SNE; write HTML report.  


## Diagnostics (what “good” looks like)
- **Top-5 neighbor overlap (pre→post):** ~3–4/5 indicates coherent integration (not collapse).
- **Norm shift (Δ‖E‖):** small positive mean (new vectors grow slightly during training, with no collapse toward zero and no runaway growth).
- **Qualitative neighbors:** morpheme-aligned (e.g., `presounder` ≈ `resound`, `ensounder`, …).
- **Tokenization:** most synthetic forms now **single IDs**.
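The overlap criterion can be made concrete. overlap@5 counts how many of a token's five pre-training nearest neighbors are still among its five nearest neighbors after training, so it ranges from 0 (complete drift) to 5 (unchanged). Below is a minimal sketch of the metric and of a check against the targets above. The `norm_delta_max` cut-off for a "small" norm shift is an assumption, since no number is given, and `looks_coherent` is a hypothetical helper; the example input reuses the Phase 1 summary values reported later in the notebook.

```python
def overlap_at5(pre_neighbors, post_neighbors):
    """Size of the intersection of two top-5 neighbor-ID lists (0..5)."""
    return len(set(pre_neighbors[:5]) & set(post_neighbors[:5]))

def looks_coherent(summary, overlap_range=(3.0, 4.0), norm_delta_max=0.5):
    """Check a summary_stats dict against the 'good' targets.

    overlap_range follows the ~3-4/5 guideline; norm_delta_max is an
    ASSUMED threshold for a 'small' positive mean norm change.
    """
    lo, hi = overlap_range
    ov = summary["mean_top5_overlap"]
    dn = summary["mean_norm_delta"]
    return lo <= ov <= hi and 0.0 < dn <= norm_delta_max

print(overlap_at5([10, 11, 12, 13, 14], [11, 10, 99, 14, 98]))  # 3

phase1 = {"mean_top5_overlap": 3.1283333333333334, "mean_norm_delta": 0.07451333105564117}
print(looks_coherent(phase1))  # True
```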

## Repro & env
- `RUN_ID = "t4_<unix>"` auto-stamped; seeds fixed at 42.
- Tested on Colab T4 with: `transformers 4.57.1`, `datasets 2.21.0`, `pyarrow 22.0.0`.
- T4 guardrails: `MAX_LEN=1024`, `gradient_checkpointing=True`, attention=`eager`, batch=1 + accum=8.

## Troubleshooting (T4)
- **CUDA OOM** → lower `MAX_LEN` to 768/512; keep batch=1; accum=8–16; ensure `use_cache=False`; `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`.
- **Version noise** → uninstall RAPIDS/TF; pin `transformers 4.57.1`, `datasets 2.21.0`, `pyarrow 22.0.0`.
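Lowering `MAX_LEN` reduces activation memory but also reduces the number of tokens behind each optimizer update; raising `gradient_accumulation_steps` in proportion keeps that budget roughly constant. A quick arithmetic sketch over the default `1024 × 8` setup and the fallbacks listed above (the specific pairings are illustrative; `tokens_per_step` is a hypothetical helper):

```python
def tokens_per_step(max_len, per_device_bs=1, grad_accum=8, n_devices=1):
    """Tokens behind one optimizer update, assuming every block is full length."""
    return max_len * per_device_bs * grad_accum * n_devices

for max_len, accum in [(1024, 8), (768, 8), (512, 8), (512, 16)]:
    print(f"MAX_LEN={max_len:4d} accum={accum:2d} -> "
          f"{tokens_per_step(max_len, grad_accum=accum):,} tokens/step")
```

With `MAX_LEN=512` and `accum=16`, each update sees the same 8,192 tokens as the default configuration, spread over twice as many micro-batches.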

---

*Wake2Vec tests morphology-aware token expansion to integrate Joyce-style neologisms into a small language model without destabilising the embedding space. We curate a prefix/suffix lexicon, generate synthetic forms, initialise new vectors by morpheme composition, and train in two phases. Evaluation reports neighbor-overlap@5, embedding-norm shifts, and qualitative neighborhoods, with JSON snapshots for reproducibility.*


In [15]:
!pip -q uninstall -y cudf-cu12 pylibcudf-cu12 cuml-cu12 dask-cudf-cu12 cupy-cuda12x tensorflow opencv-python-headless opencv-contrib-python opencv-python >/dev/null

!pip -q install --no-cache-dir --upgrade-strategy eager \
  "transformers==4.57.1" "accelerate>=0.33" "tokenizers>=0.15" "safetensors" \
  "datasets==2.21.0" "evaluate>=0.4.0" "pyarrow==22.0.0" \
  "huggingface-hub>=0.34,<1.0" "bitsandbytes>=0.43" \
  "umap-learn" "faiss-cpu" "wordfreq" "Unidecode" "matplotlib" "scikit-learn"

import transformers, datasets, pyarrow
print("Transformers:", transformers.__version__)
print("datasets:", datasets.__version__)
print("pyarrow:", pyarrow.__version__)

Transformers: 4.57.1
datasets: 2.21.0
pyarrow: 22.0.0


Imports, seeds, run IDs, paths

In [27]:
import os, json, time, random
from pathlib import Path
import numpy as np
import torch

SEED = 42
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

RUN_ID = f"t4_{int(time.time())}"
ROOT = Path("/content")
RUN_DIR = ROOT / "runs" / RUN_ID
PLOTS_DIR = RUN_DIR / "plots"
METRICS_DIR = RUN_DIR / "metrics"
REPORTS_DIR = ROOT / "reports"
for p in (PLOTS_DIR, METRICS_DIR, REPORTS_DIR): p.mkdir(parents=True, exist_ok=True)

META = {
    "run_id": RUN_ID, "seed": SEED, "alpha": 0.25,
    # ptd_bs = per-device train batch size; values match the T4 setup (batch 1, accumulation 8)
    "phase1": {"lr": 5e-4, "epochs": 1, "ptd_bs": 1, "grad_accum": 8},
    "phase2": {"lr": 2e-5, "epochs": 2, "warmup_ratio": 0.10, "ptd_bs": 1, "grad_accum": 8, "weight_decay": 0.01}
}
(METRICS_DIR/"meta.json").write_text(json.dumps(META, indent=2))
print("RUN_ID:", RUN_ID)

RUN_ID: t4_1761966609


Load data

In [28]:
from pathlib import Path
import json, csv

DATA_DIR = ROOT/"data"; DATA_DIR.mkdir(parents=True, exist_ok=True)
# Relative names: joining DATA_DIR with an absolute string such as "/content/x" would discard DATA_DIR
FW_PATH   = DATA_DIR/"FW_TEXT.txt"
JSON_PATH = DATA_DIR/"morpheme_data.json"
CSV_PATH  = DATA_DIR/"morphemes.csv"

def load_morpheme_csv(path):
    d = {"prefixes": {}, "suffixes": {}}
    with open(path, newline="", encoding="utf-8") as f:
        rdr = csv.reader(f); header = next(rdr, None)
        for row in rdr:
            if not row: continue
            typ, morpheme, *examples = [x.strip() for x in row]
            if typ not in ("prefix","suffix"): continue
            key = "prefixes" if typ=="prefix" else "suffixes"
            ex = [w for w in dict.fromkeys(examples) if w]
            if ex: d[key][morpheme] = ex
    return d

if JSON_PATH.exists():
    MORPHEME_DATA = json.loads(JSON_PATH.read_text(encoding="utf-8"))
elif CSV_PATH.exists():
    MORPHEME_DATA = load_morpheme_csv(CSV_PATH)
else:
    raise FileNotFoundError("Put morpheme_data.json or morphemes.csv in /content/data")

prefixes = MORPHEME_DATA.get("prefixes", {})
suffixes = MORPHEME_DATA.get("suffixes", {})

if not FW_PATH.exists():
    FW_PATH.write_text("Placeholder FW text.\n"*5000, encoding="utf-8")
FW_TEXT = FW_PATH.read_text(encoding="utf-8")

print(f"Prefixes: {len(prefixes)} | Suffixes: {len(suffixes)} | FW chars: {len(FW_TEXT):,}")

Prefixes: 15 | Suffixes: 15 | FW chars: 1,364,712


Synthetic generator

In [29]:
import random
def synthetic_words(n=1200, roots=("river thunder word sound dance queen storm tree night sun rain book".split())):
    out=set()
    pfx_pool=[p for p,ex in prefixes.items() for _ in range(max(1,len(ex)//2+1))]
    sfx_pool=[s for s,ex in suffixes.items() for _ in range(max(1,len(ex)//2+1))]
    for _ in range(max(2*n, 2000)):
        if not pfx_pool or not sfx_pool: break
        p=random.choice(pfx_pool); s=random.choice(sfx_pool); r=random.choice(roots)
        if len(p)+len(r)+len(s)>3: out.add(f"{p}{r}{s}")
        if len(out)>=n: break
    return sorted(out)

SYN_WORDS = synthetic_words()
SYN_LINES = [f"The {w} rolled down the river at night." for w in random.sample(SYN_WORDS, min(400,len(SYN_WORDS)))]
print("Synthetic words:", len(SYN_WORDS), "| synthetic lines:", len(SYN_LINES))

Synthetic words: 1200 | synthetic lines: 400
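In `synthetic_words`, each affix appears `max(1, len(examples)//2 + 1)` times in its pool, so `random.choice` on the pool samples affixes with weight proportional to that count. The sketch below (with made-up example lists, not the real lexicon) shows that this matches an explicitly weighted `random.choices` draw, and computes the upper bound on distinct forms: with 15 prefixes, the 12 default roots, and 15 suffixes there are at most 15 × 12 × 15 = 2,700 combinations, which leaves headroom above the 1,200-word target. `pool_weight` and `max_distinct_forms` are hypothetical helper names.

```python
import random
from collections import Counter

def pool_weight(examples):
    """Copies an affix receives in the generator's pool: max(1, len(examples)//2 + 1)."""
    return max(1, len(examples) // 2 + 1)

def max_distinct_forms(n_prefixes, n_roots, n_suffixes):
    """Upper bound on distinct prefix+root+suffix strings (collisions only lower it)."""
    return n_prefixes * n_roots * n_suffixes

# Illustrative affix -> example words (not the real lexicon).
toy_prefixes = {"pre": ["preview", "prepare", "predict", "prefix"], "re": ["return"]}
weights = {p: pool_weight(ex) for p, ex in toy_prefixes.items()}
print(weights)  # {'pre': 3, 're': 1}

# Drawing from the repeated pool, as synthetic_words does ...
pool = [p for p, ex in toy_prefixes.items() for _ in range(pool_weight(ex))]
rng = random.Random(42)
print(Counter(rng.choice(pool) for _ in range(10_000)))
# ... has the same distribution as an explicitly weighted draw (about 3:1 here).
print(Counter(rng.choices(list(weights), weights=list(weights.values()), k=10_000)))

print(max_distinct_forms(15, 12, 15))  # 2700
```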


Base model, expand tokenizer, compose embeddings, tie head


In [30]:
from transformers import AutoTokenizer, AutoModelForCausalLM

BASE_MODEL = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, dtype="float32", device_map="auto")

pre_token_splits = {w: tok.encode(w, add_special_tokens=False) for w in SYN_WORDS}
new_tokens = [w for w, ids in pre_token_splits.items() if len(ids) > 1]
added = tok.add_tokens(new_tokens, special_tokens=False)
model.resize_token_embeddings(len(tok), mean_resizing=False)
print(f"Added tokens: {added} | Vocab size: {len(tok)}")

import torch
def avg_vec(terms, emb, tok):
    vecs=[]
    for t in terms:
        ids = tok.encode(t, add_special_tokens=False)
        if len(ids)==1: vecs.append(emb.weight.data[ids[0]])
    return torch.stack(vecs,0).mean(0) if vecs else None

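with torch.no_grad():
    emb = model.get_input_embeddings()
    alpha = META["alpha"]; std = emb.weight.data.std().item()
    for w in new_tokens:
        # First prefix/suffix (in lexicon order) that the word starts/ends with
        p = next((p for p in prefixes if w.startswith(p)), None)
        s = next((s for s in suffixes if w.endswith(s)), None)
        # Root = what remains between the affixes; fall back to the whole word
        root = w[len(p):len(w)-len(s)] if (p and s and len(w)>len(p)+len(s)) else w
        # Affix vectors = mean embedding of single-token example words (None if none qualify)
        vp = avg_vec(prefixes.get(p, []), emb, tok)
        vs = avg_vec(suffixes.get(s, []), emb, tok)
        # Root vector: its embedding if the root is a single token, else scaled random noise
        vr_ids = tok.encode(root, add_special_tokens=False)
        vr = emb.weight.data[vr_ids[0]] if len(vr_ids)==1 else torch.randn(emb.embedding_dim, device=emb.weight.device)*(std*0.5)
        # E(word) = α·E(prefix) + (1−2α)·E(root) + α·E(suffix) + ε; a missing affix is replaced by the root
        comp = alpha*(vp if vp is not None else vr) + (1-2*alpha)*vr + alpha*(vs if vs is not None else vr)
        comp = comp + torch.randn_like(comp)*(std*0.01)   # ε: small Gaussian noise
        emb.weight.data[tok.convert_tokens_to_ids(w)] = comp
    # Tie the output head to the input embeddings (replaces the model's separate lm_head matrix)
    model.lm_head.weight = emb.weight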

print("Composed embeddings + tied head.")

Added tokens: 1200 | Vocab size: 33200
Composed embeddings + tied head.
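The composition loop takes the first prefix (and suffix) in the lexicon's iteration order that matches, so when one affix is itself the start of another (for example, a hypothetical `in` and `inter`), the recovered root depends on dictionary order. Because every synthetic form is built from a known root list, one alternative, sketched below under that assumption and not what the notebook does, is to prefer the split whose middle is a known root. Both helper names are hypothetical; `split_first_match` mirrors the notebook's rule for comparison.

```python
def split_first_match(word, prefixes, suffixes):
    """Mirror of the notebook's rule: first matching affix in iteration order."""
    p = next((a for a in prefixes if word.startswith(a)), None)
    s = next((a for a in suffixes if word.endswith(a)), None)
    clean = p is not None and s is not None and len(word) > len(p) + len(s)
    return p, (word[len(p):len(word) - len(s)] if clean else word), s

def split_with_roots(word, prefixes, suffixes, roots):
    """Hypothetical helper: prefer the affix pair whose middle is a known root.

    Tries longer affixes first; returns (None, word, None) if no pair leaves a
    known root, in which case a caller could fall back to split_first_match.
    """
    pre = sorted((a for a in prefixes if word.startswith(a)), key=len, reverse=True)
    suf = sorted((a for a in suffixes if word.endswith(a)), key=len, reverse=True)
    for p in pre:
        for s in suf:
            if len(word) > len(p) + len(s):
                middle = word[len(p):len(word) - len(s)]
                if middle in roots:
                    return p, middle, s
    return None, word, None

# Illustrative affixes where one prefix nests inside another (not the real lexicon).
toy_prefixes = ["in", "inter"]
toy_suffixes = ["ness", "s"]
toy_roots = {"tree", "sun", "river"}

print(split_first_match("intertreeness", toy_prefixes, toy_suffixes))            # ('in', 'tertree', 'ness')
print(split_with_roots("intertreeness", toy_prefixes, toy_suffixes, toy_roots))  # ('inter', 'tree', 'ness')
```

A simpler option with the same effect would be to record the `(prefix, root, suffix)` triple when `synthetic_words` builds each form, so no splitting is needed at composition time.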


In [31]:
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
model.config.pad_token_id = tok.pad_token_id

Blocks + PRE snapshot

In [32]:
from datasets import Dataset
from transformers import DataCollatorForLanguageModeling
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import json

MAX_LEN=1024; STRIDE=512
def make_blocks(text, max_len=MAX_LEN, stride=STRIDE):
    ids = tok.encode(text, add_special_tokens=False)
    return [{"input_ids": ids[i:i+max_len]} for i in range(0, max(0,len(ids)-max_len), stride) if len(ids[i:i+max_len])==max_len]

train_text = "\n".join(SYN_LINES) + "\n" + FW_TEXT[:600_000]
valid_text = FW_TEXT[600_000:630_000]
train_ds = Dataset.from_list(make_blocks(train_text))
valid_ds = Dataset.from_list(make_blocks(valid_text))
dc = DataCollatorForLanguageModeling(tok, mlm=False)

print("Train blocks:", len(train_ds), "| Valid blocks:", len(valid_ds))

with torch.no_grad():
    W_pre = model.get_input_embeddings().weight.detach().clone().to("cpu").numpy()
    new_ids = [tok.convert_tokens_to_ids(t) for t in new_tokens]
    sim_pre = cosine_similarity(W_pre[new_ids], W_pre)
    top5_pre = np.argsort(-sim_pre, axis=1)[:,1:6]

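# Neighbor IDs are saved for the first 50 new tokens; the full top5_pre array stays in memory for later comparisons
(METRICS_DIR/"pre_morpheme_snapshot.json").write_text(
    json.dumps({"new_tokens": new_tokens, "top5_pre": top5_pre[:50].tolist()}, indent=2))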
print("Saved pre snapshot.")

Train blocks: 372 | Valid blocks: 16
Saved pre snapshot.
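`make_blocks` cuts the token stream into fixed-length windows that start every `STRIDE` tokens, so with `MAX_LEN=1024` and `STRIDE=512` consecutive blocks share half their tokens, and a tail shorter than a full window is dropped. The pure-Python sketch below applies the same windowing to a list of integer IDs so the counts can be checked without a tokenizer (`make_blocks_from_ids` is a hypothetical name). One detail worth noting: because `range` excludes its end point, a window starting exactly at `len(ids) - max_len` is never emitted.

```python
def make_blocks_from_ids(ids, max_len=1024, stride=512):
    """Same windowing as the notebook's make_blocks, applied to a pre-tokenized list."""
    return [ids[i:i + max_len]
            for i in range(0, max(0, len(ids) - max_len), stride)
            if len(ids[i:i + max_len]) == max_len]

ids = list(range(3000))
blocks = make_blocks_from_ids(ids)
print(len(blocks), [b[0] for b in blocks])   # 4 [0, 512, 1024, 1536]
```

As a consequence, a stream of exactly `max_len` tokens yields no blocks at all; iterating over `range(0, len(ids) - max_len + 1, stride)` would include that final window.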


P1 — embeddings-only warm-up

In [33]:
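import os, torch, gc
# NOTE: PyTorch reads PYTORCH_CUDA_ALLOC_CONF when the CUDA caching allocator is initialised,
# so this only takes effect if it is set before the model is loaded onto the GPU.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# TF32 requires Ampere or newer GPUs; on a T4 (Turing) this flag has no effect.
torch.backends.cuda.matmul.allow_tf32 = True
gc.collect(); torch.cuda.empty_cache()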

In [35]:
# P1 Embedding-only warm-up
import gc, torch
from transformers import Trainer, TrainingArguments

# Freeze everything except embeddings and lm_head
def freeze_all_but_embeddings(m):
    for p in m.parameters():
        p.requires_grad = False
    for p in m.get_input_embeddings().parameters():
        p.requires_grad = True
    for p in m.lm_head.parameters():
        p.requires_grad = True

freeze_all_but_embeddings(model)

# Trainer args — tiny batch, big accum, checkpointing
model.config.use_cache = False
args1 = TrainingArguments(
    output_dir=str(RUN_DIR/"phase1"),
    per_device_train_batch_size=1,     # tiny batch
    gradient_accumulation_steps=8,
    learning_rate=META["phase1"]["lr"],
    num_train_epochs=META["phase1"]["epochs"],
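    eval_strategy="steps",
    save_strategy="steps",
    # NOTE: 372 blocks / (1 × 8) gives ~47 optimizer steps, so with these values no eval,
    # checkpoint, or loss log fires during Phase 1; lower them to monitor the run.
    save_steps=400,
    eval_steps=400,
    logging_strategy="steps",
    logging_steps=50,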
    gradient_checkpointing=True,
    fp16=False,                        # fp16 fragile on T4
    load_best_model_at_end=False,
    report_to="none",
    optim="adamw_bnb_8bit",            # 8-bit AdamW; requires bitsandbytes (installed above)
)

trainer1 = Trainer(
    model=model,
    args=args1,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    data_collator=dc,
)

trainer1.train()

The model is already on multiple devices. Skipping the move to device specified in `args`.




TrainOutput(global_step=47, training_loss=6.399282252534907, metrics={'train_runtime': 1083.9698, 'train_samples_per_second': 0.343, 'train_steps_per_second': 0.043, 'total_flos': 2214661416026112.0, 'train_loss': 6.399282252534907, 'epoch': 1.0})
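Because `lm_head.weight` is tied to the input embedding matrix, Phase 1 trains a single vocab × hidden matrix. Assuming TinyLlama's hidden size of 2048 and the expanded vocabulary of 33,200 reported above, a back-of-the-envelope estimate of the trainable footprint is sketched below. The optimizer-state sizes are approximations: 8-bit optimizers also keep small per-block quantization statistics, which are ignored here. `embedding_footprint` is a hypothetical helper.

```python
def embedding_footprint(vocab=33_200, hidden=2048, weight_bytes=4, state_bytes=1, n_states=2):
    """Approximate (params, bytes) for a tied embedding/lm_head matrix during training.

    Counts float32 weights, float32 gradients, and n_states optimizer buffers.
    state_bytes=1 approximates 8-bit Adam; state_bytes=4 is standard float32 AdamW.
    """
    params = vocab * hidden
    weights = params * weight_bytes
    grads = params * weight_bytes
    states = params * state_bytes * n_states
    return params, weights + grads + states

MiB = 2 ** 20
params, total_8bit = embedding_footprint()
_, total_fp32 = embedding_footprint(state_bytes=4)
print(f"trainable params: {params:,}")          # 67,993,600
print(f"8-bit Adam : ~{total_8bit / MiB:,.0f} MiB")
print(f"fp32 AdamW : ~{total_fp32 / MiB:,.0f} MiB")
```

The frozen transformer blocks still occupy memory on top of this: roughly 1.1B float32 parameters, on the order of 4.4 GB.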

In [36]:
# MID SNAPSHOT
import json, torch, numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from pathlib import Path
import matplotlib.pyplot as plt

MID_DIR = METRICS_DIR  # reuse same folder

with torch.no_grad():
    W_mid = model.get_input_embeddings().weight.detach().clone().to("cpu").numpy()
    sim_mid = cosine_similarity(W_mid[new_ids], W_mid)
    top5_mid = np.argsort(-sim_mid, axis=1)[:,1:6]

def overlap_at5(a,b): return len(set(a.tolist()) & set(b.tolist()))

overlaps_p1 = np.array([overlap_at5(top5_pre[i], top5_mid[i]) for i in range(len(new_ids))])
norms_pre   = np.linalg.norm(W_pre[new_ids], axis=1)
norms_mid   = np.linalg.norm(W_mid[new_ids], axis=1)
norm_deltas_p1 = norms_mid - norms_pre

summary_p1 = {
    "phase": "phase1",
    "compared_tokens": int(len(new_ids)),
    "mean_top5_overlap": float(np.mean(overlaps_p1)) if len(overlaps_p1) else None,
    "mean_norm_delta": float(np.mean(norm_deltas_p1)) if len(norm_deltas_p1) else None,
}

# save JSONs
(Path(MID_DIR)/"morpheme_comparison_p1.json").write_text(
    json.dumps({
        "top5_pre": top5_pre.tolist(),
        "top5_mid": top5_mid.tolist(),
        "overlap@5": overlaps_p1.tolist(),
        "norm_deltas": norm_deltas_p1.tolist(),
    }, indent=2)
)
(Path(MID_DIR)/"summary_stats_p1.json").write_text(json.dumps(summary_p1, indent=2))
print("Phase-1 summary:", summary_p1)

# quick plots
PLOTS_DIR.mkdir(parents=True, exist_ok=True)
plt.figure(); plt.hist(overlaps_p1, bins=[-0.5,0.5,1.5,2.5,3.5,4.5,5.5])
plt.title("Top-5 overlap (PRE → MID)"); plt.xlabel("Overlap"); plt.ylabel("Freq")
plt.tight_layout(); plt.savefig(PLOTS_DIR/"hist_overlap_top5_p1.png", dpi=180); plt.close()

plt.figure(); plt.hist(norm_deltas_p1, bins=30)
plt.title("Embedding norm change (MID − PRE)"); plt.xlabel("Δ norm"); plt.ylabel("Freq")
plt.tight_layout(); plt.savefig(PLOTS_DIR/"hist_norm_change_p1.png", dpi=180); plt.close()

Phase-1 summary: {'phase': 'phase1', 'compared_tokens': 1200, 'mean_top5_overlap': 3.1283333333333334, 'mean_norm_delta': 0.07451333105564117}
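The snapshots take `np.argsort(-sim, axis=1)[:, 1:6]`, which assumes that each token's most similar vector is itself. That normally holds (self-similarity is 1), but ties or exact duplicates can push the token out of position 0 and shift its neighbor list by one. A small sketch that drops the query ID explicitly instead, in plain Python on toy vectors for clarity (the same idea applies to rows of `W_pre` / `W_mid` with NumPy); `top_k_neighbors` is a hypothetical helper:

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def top_k_neighbors(query_id, vectors, k=5):
    """IDs of the k most cosine-similar vectors, excluding query_id itself."""
    q = vectors[query_id]
    scored = [(cosine(q, v), i) for i, v in enumerate(vectors) if i != query_id]
    scored.sort(key=lambda t: (-t[0], t[1]))   # highest similarity first, ties broken by ID
    return [i for _, i in scored[:k]]

# Toy table: row 1 duplicates row 0, so the self-match is not unique.
vectors = [[1.0, 0.0], [1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.5, 0.5], [-1.0, 0.0], [0.7, 0.3]]
print(top_k_neighbors(1, vectors, k=3))  # [0, 2, 6]
```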


P2 full-model fine-tune

In [37]:
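# Snapshot the full (post-Phase-1) embedding matrix; the anchor term below pulls rows back toward it
with torch.no_grad():
    E_init = model.get_input_embeddings().weight.data.clone()

# L2 anchor penalty, intended to be added to the LM loss during Phase 2
LAMBDA = 1e-4
def add_anchor_loss(outputs, inputs):
    input_ids = inputs["input_ids"]
    emb = model.get_input_embeddings().weight
    ids = torch.unique(input_ids)   # every token ID in the batch (new and pre-existing), counted once
    ids = ids[ids >= 0]             # guard against negative IDs such as -100 label padding
    # λ · mean squared drift over all elements of those rows (`outputs` is unused)
    return LAMBDA * (emb[ids] - E_init[ids]).pow(2).mean()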
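`add_anchor_loss` is an L2 anchor: it returns λ times the mean squared difference between the current and snapshotted embedding rows for the token IDs present in the batch, which discourages Phase 2 from dragging embeddings far from where Phase 1 left them. The cell only defines the term; wiring it in would typically mean subclassing `Trainer` and overriding `compute_loss` to add it to the language-modelling loss. A framework-free sketch of the same arithmetic on toy rows (`anchor_penalty` is a hypothetical name):

```python
def anchor_penalty(current, initial, batch_ids, lam=1e-4):
    """lam * mean squared difference over every element of the rows for the
    unique token IDs in the batch (mirrors torch.unique + .pow(2).mean())."""
    ids = sorted(set(batch_ids))
    diffs = [(c - i) ** 2 for t in ids for c, i in zip(current[t], initial[t])]
    return lam * sum(diffs) / len(diffs)

initial = {0: [0.0, 0.0], 1: [1.0, 1.0], 2: [2.0, 2.0]}   # toy snapshot (E_init)
current = {0: [0.0, 0.0], 1: [1.0, 3.0], 2: [2.0, 2.0]}   # token 1 has drifted
print(anchor_penalty(current, initial, batch_ids=[1, 1, 0]))
```

As written, the notebook anchors every token that appears in the batch; restricting `ids` to `new_ids` would anchor only the added tokens.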

In [39]:
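# Phase 2 trains all weights, so switch to the shorter T4 fallback window (see Troubleshooting)
MAX_LEN = 768
STRIDE  = 384

# Redefine so the default arguments pick up the new MAX_LEN/STRIDE (defaults are bound at definition time)
def make_blocks(text, max_len=MAX_LEN, stride=STRIDE):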
    ids = tok.encode(text, add_special_tokens=False)
    return [{"input_ids": ids[i:i+max_len]} for i in range(0, max(0,len(ids)-max_len), stride) if len(ids[i:i+max_len])==max_len]

train_text = "\n".join(SYN_LINES) + "\n" + FW_TEXT[:600_000]
valid_text = FW_TEXT[600_000:630_000]
train_ds = Dataset.from_list(make_blocks(train_text))
valid_ds = Dataset.from_list(make_blocks(valid_text))

from transformers import DataCollatorForLanguageModeling
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
model.config.pad_token_id = tok.pad_token_id
dc = DataCollatorForLanguageModeling(tok, mlm=False)

len(train_ds), len(valid_ds)


(497, 22)
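The block counts translate directly into optimizer steps. Assuming the Trainer counts a final partial accumulation window as a full step (consistent with `global_step=47` for 372 blocks in Phase 1) and computes warm-up as `ceil(warmup_ratio × total_steps)`, the sketch below estimates the Phase 2 schedule for the 497 blocks above. The batch settings (per-device batch 1, accumulation 8) and 2 epochs are assumptions carried over from the run controls and `META`, since the Phase 2 training cell is not shown here; `schedule` is a hypothetical helper.

```python
import math

def schedule(n_blocks, per_device_bs=1, grad_accum=8, epochs=1, warmup_ratio=0.0):
    """Return (steps_per_epoch, total_steps, warmup_steps) for a single-device run."""
    batches = math.ceil(n_blocks / per_device_bs)        # dataloader length (no drop_last)
    steps_per_epoch = math.ceil(batches / grad_accum)    # a partial window still counts as a step
    total = steps_per_epoch * epochs
    return steps_per_epoch, total, math.ceil(total * warmup_ratio)

print(schedule(372))                                # Phase 1: (47, 47, 0)
print(schedule(497, epochs=2, warmup_ratio=0.10))   # Phase 2: (63, 126, 13)
```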