# Wake2Vec Morpheme Expansion Pipeline

This notebook documents a controlled procedure for integrating Joyce-style neologisms into a compact GPT-type language model through morphology-aware token expansion. I curate a small lexicon of prefixes and suffixes, generate synthetic candidates, and then extend the tokenizer so that neologisms it previously split are admitted as single tokens. New embeddings are initialised by morphemic composition, using the rule \(E(\text{word}) = \alpha\,E(\text{prefix}) + (1 - 2\alpha)\,E(\text{root}) + \alpha\,E(\text{suffix}) + \varepsilon\), where \(\alpha\) is a fixed weight and \(\varepsilon\) is small Gaussian noise that prevents identical vectors. Training proceeds in two stages: an embedding-only warm-up on a mixture of synthetic lines and *Finnegans Wake* text, followed by a short full-model fine-tune under conservative schedules suitable for a T4 environment.
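
As a minimal sketch of the composition rule, the snippet below applies it to toy vectors with NumPy. The function name `compose_embedding`, the `noise_scale` parameter, and the 4-dimensional vectors are illustrative assumptions, not part of the notebook's own code.

```python
import numpy as np

def compose_embedding(e_prefix, e_root, e_suffix, alpha=0.25, noise_scale=0.0, rng=None):
    """Return alpha*prefix + (1 - 2*alpha)*root + alpha*suffix + Gaussian noise."""
    rng = rng or np.random.default_rng(0)
    base = alpha * e_prefix + (1.0 - 2.0 * alpha) * e_root + alpha * e_suffix
    eps = noise_scale * rng.standard_normal(base.shape)  # noise_scale=0 disables the noise term
    return base + eps

# Toy one-hot vectors (illustrative values only) make the weights visible.
e_pre = np.array([1.0, 0.0, 0.0, 0.0])
e_root = np.array([0.0, 1.0, 0.0, 0.0])
e_suf = np.array([0.0, 0.0, 1.0, 0.0])

composed = compose_embedding(e_pre, e_root, e_suf, alpha=0.25)
print(composed)  # the three components carry weights 0.25, 0.5 and 0.25
```

With \(\alpha = 0.25\) the root contributes half of the vector and each affix a quarter, so the three weights sum to one.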

 I report top-five neighbor overlap for the newly introduced tokens before and after training, track shifts in embedding norms, provide a t-SNE projection of the new tokens against pre-training neighbor centroids, and save JSON snapshots of neighborhoods at each stage. These diagnostics are intended to show coherent integration of the new forms into the embedding space rather than collapse or runaway drift, and to make the procedure straightforward to reproduce on modest hardware.

**Config**

Base model: `TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T`. Composition weight \(\alpha = 0.25\). Maximum sequence length set to 1024 to respect T4 memory limits. Batching uses `per_device_train_batch_size = 1` with `gradient_accumulation_steps = 8`, attention implementation set to `eager`, and `use_cache = False`. Phase 1 trains input embeddings and the tied output head only; Phase 2 unfreezes all parameters with a warm-up ratio of 0.10 and light weight decay. All runs write plots and machine-readable artifacts to `runs/<RUN_ID>/` and generate a brief HTML report.

---

## Run controls
- **BASE_MODEL:** `TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T`
- **α (composition weight):** `0.25` (can tune)
- **Max seq length:** `1024` (T4-safe; raise only if VRAM allows)
- **Batching:** `per_device_train_batch_size=1`, `gradient_accumulation_steps=8`
- **Attn impl:** `eager` (avoid SDPA spikes on T4)
- **Two phases:**
  - **Phase 1:** embeddings + lm_head only, Adafactor/8-bit Adam, 1 epoch
  - **Phase 2:** full model, short run, warmup 0.10
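
The batching controls above imply an effective batch size that is worth stating explicitly. The arithmetic below is a small sketch; `num_devices = 1` is an assumption (a single T4), and the token figure is an upper bound because shorter sequences use fewer tokens.

```python
per_device_train_batch_size = 1
gradient_accumulation_steps = 8
max_seq_length = 1024
num_devices = 1  # assumption: one T4 GPU

# Sequences that contribute to each optimizer update.
sequences_per_update = per_device_train_batch_size * gradient_accumulation_steps * num_devices
# Upper bound on tokens per update (every sequence at full length).
max_tokens_per_update = sequences_per_update * max_seq_length

print(sequences_per_update, max_tokens_per_update)  # 8 8192
```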

## Inputs
- `data/FW_TEXT.txt` — Finnegans Wake plain text (slice for demo)
- `data/morpheme_data.json` or `data/morphemes.csv`  
  Structure maps:
  - `prefixes`: `{ prefix → [example words…] }`
  - `suffixes`: `{ suffix → [example words…] }`
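
A small sketch of the expected JSON shape, with a validator, may help when preparing `morpheme_data.json`. The helper `validate_morpheme_data` and the example words are hypothetical; only the two top-level keys and the morpheme-to-list mapping come from the description above.

```python
import json

def validate_morpheme_data(data):
    """Check the two-level mapping described above; raise ValueError on the first problem."""
    for section in ("prefixes", "suffixes"):
        if section not in data or not isinstance(data[section], dict):
            raise ValueError(f"missing or malformed section: {section}")
        for morpheme, examples in data[section].items():
            if not isinstance(morpheme, str) or not morpheme:
                raise ValueError(f"bad morpheme key in {section}: {morpheme!r}")
            if not isinstance(examples, list) or not all(isinstance(w, str) for w in examples):
                raise ValueError(f"examples for {morpheme!r} must be a list of strings")
    return data

# Hypothetical payload: the example words are made up for illustration.
sample_json = """
{
  "prefixes": {"en": ["enriver", "ensound"]},
  "suffixes": {"ster": ["rainster", "bookster"]}
}
"""
morpheme_data = validate_morpheme_data(json.loads(sample_json))
print(sorted(morpheme_data["prefixes"]), sorted(morpheme_data["suffixes"]))
```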

## Outputs (per run)
- `runs/<RUN_ID>/metrics/`
  - `pre_morpheme_snapshot.json`
  - `morpheme_comparison_p1.json` *(midpoint, after Phase 1)*
  - `morpheme_comparison.json` *(final, after Phase 2)*
  - `summary_stats_p1.json`, `summary_stats.json`
- `runs/<RUN_ID>/plots/`
  - `hist_overlap_top5(_p1).png`, `hist_norm_change(_p1).png`
  - `scatter_norm_vs_overlap.png`, `tsne_newtokens_vs_precentroids.png`
- `reports/Wake2Vec_Report.html`

## Quickstart
1. **Reset & install** deps (Colab-friendly).  
2. **Load data** (prefers JSON).  
3. **Generate** synthetic forms (prefix + root + suffix).  
4. **Expand tokenizer** (add new tokens); compose embeddings with α-rule; tie head.  
5. **Phase 1**: train embeddings only. Saves midpoint snapshot.
6. **Phase 2**: unfreeze and short fine-tune.  
7. **Diagnostics**: compute overlap@5, norm deltas, t-SNE; write HTML report.  


## Diagnostics (what “good” looks like)
- **Top-5 neighbor overlap (pre→post):** ~3–4/5 indicates coherent integration (not collapse).
- **Norm shift (Δ‖E‖):** small positive mean (slight energy increase from training).
- **Qualitative neighbors:** morpheme-aligned (e.g., `presounder` ≈ `resound`, `ensounder`, …).
- **Tokenization:** most synthetic forms now **single IDs**.
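
The first two diagnostics can be sketched on random data as follows. This is an illustration under stated assumptions, not the notebook's evaluation code: the helper names, the 50×16 toy matrix, and the perturbation standing in for training are all hypothetical.

```python
import numpy as np

def top_k_neighbors(query, matrix, k=5, exclude=None):
    """Indices of the k rows of `matrix` most cosine-similar to `query`."""
    q = query / (np.linalg.norm(query) + 1e-9)
    m = matrix / (np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-9)
    sims = m @ q
    if exclude is not None:
        sims[exclude] = -np.inf  # a token should not count as its own neighbor
    return np.argsort(-sims)[:k]

def overlap_at_k(pre_ids, post_ids):
    """Number of neighbor ids shared by the two lists (0..k)."""
    return len(set(pre_ids.tolist()) & set(post_ids.tolist()))

rng = np.random.default_rng(42)
vocab_pre = rng.standard_normal((50, 16))                       # stand-in for pre-training embeddings
vocab_post = vocab_pre + 0.05 * rng.standard_normal((50, 16))   # small drift standing in for training

token_id = 7  # hypothetical id of a newly added token
pre = top_k_neighbors(vocab_pre[token_id], vocab_pre, exclude=token_id)
post = top_k_neighbors(vocab_post[token_id], vocab_post, exclude=token_id)

score = overlap_at_k(pre, post)
norm_delta = np.linalg.norm(vocab_post[token_id]) - np.linalg.norm(vocab_pre[token_id])
print(f"overlap@5 = {score}/5, norm shift = {norm_delta:+.4f}")
```

Excluding the token's own row matters: otherwise the self-match occupies one of the five slots and distorts the overlap count.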

## Repro & env
- `RUN_ID = "t4_<unix>"` auto-stamped; seeds fixed at 42.
- Tested on Colab T4 with: `transformers 4.57.1`, `datasets 2.21.0`, `pyarrow 22.0.0`.
- T4 guardrails: `MAX_LEN=1024`, `gradient_checkpointing=True`, attention=`eager`, batch=1 + accum=8.

## Troubleshooting (T4)
- **CUDA OOM** → lower `MAX_LEN` to 768/512; keep batch=1; accum=8–16; ensure `use_cache=False`; `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`.
- **Version noise** → uninstall RAPIDS/TF; pin `transformers 4.57.1`, `datasets 2.21.0`, `pyarrow 22.0.0`.
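
The OOM advice can be sketched as a small fallback ladder. The helper `next_max_len` is hypothetical and not part of the notebook; the environment variable and the lengths 1024/768/512 are the ones listed above.

```python
import os

# Set the allocator option before any CUDA memory is allocated;
# doing it before `import torch` is the safe choice.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

def next_max_len(current, ladder=(1024, 768, 512)):
    """Return the next smaller length in the ladder, or the smallest one at the bottom."""
    smaller = [n for n in ladder if n < current]
    return max(smaller) if smaller else min(ladder)

max_len = 1024
max_len = next_max_len(max_len)  # after a first OOM
print(max_len)  # 768
```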

---

 *Wake2Vec tests morphology-aware token expansion to integrate Joyce-style neologisms into a small language model without destabilising the embedding space. We curate a prefix/suffix lexicon, generate synthetic forms, initialise new vectors by morpheme composition, and train in two phases. Evaluation reports neighbor-overlap@5, embedding-norm shifts, and qualitative neighborhoods, with JSON snapshots for reproducibility.*


In [None]:
!pip -q install --no-cache-dir --upgrade-strategy eager \
  "transformers==4.57.1" "datasets==2.21.0" "accelerate==1.0.1" \
  "peft==0.12.0" "bitsandbytes==0.43.3" \
  "huggingface-hub>=0.34,<1.0" \
  "pyarrow==22.0.0" "numpy==2.0.2" "pandas==2.2.2" "requests==2.32.4" \
  "matplotlib>=3.8" "scikit-learn>=1.5" "umap-learn" "faiss-cpu" "wordfreq" "Unidecode"
import os; os.kill(os.getpid(), 9)  # kill the kernel so Colab restarts the runtime with the new packages


Imports, seeds, run IDs, paths

In [None]:
import numpy as np, torch, transformers, datasets, pyarrow as pa, json, time, random, gc, os
from pathlib import Path
from google.colab import drive

print("Transformers:", transformers.__version__)
print("Datasets    :", datasets.__version__)
print("PyArrow     :", pa.__version__)
print("Torch       :", torch.__version__)
print("CUDA        :", torch.version.cuda)
drive.mount('/content/drive', force_remount=True)

# Paths
PROJECT     = "wake2vec"
BASE_MODEL  = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
RUN_ID      = f"t4_{int(time.time())}"

ROOT        = Path("/content")
PERSIST     = Path("/content/drive/MyDrive")/PROJECT
RUN_DIR     = ROOT/"runs"/RUN_ID
METRICS_DIR = RUN_DIR/"metrics"; PLOTS_DIR = RUN_DIR/"plots"; REPORTS_DIR = ROOT/"reports"
ADAPT_DIR   = RUN_DIR/"phase2_lora"/"final_adapters"; TOK_SAVE = ADAPT_DIR  # keep tokenizer here too

for d in [RUN_DIR, METRICS_DIR, PLOTS_DIR, REPORTS_DIR, ADAPT_DIR, PERSIST/"runs", PERSIST/"adapters", PERSIST/"reports", PERSIST/"archives", PERSIST/"notebooks"]:
    d.mkdir(parents=True, exist_ok=True)

# clean & seeds
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.backends.cuda.matmul.allow_tf32 = True
def set_seed(s=42):
    random.seed(s); np.random.seed(s); torch.manual_seed(s); torch.cuda.manual_seed_all(s)
set_seed(42)

print("RUN_ID:", RUN_ID)

Transformers: 4.57.1
Datasets    : 2.21.0
PyArrow     : 22.0.0
Torch       : 2.8.0+cu126
CUDA        : 12.6
Mounted at /content/drive
RUN_ID: t4_1762113417


Load data

In [None]:
import pandas as pd, json, re
from collections import defaultdict
from pathlib import Path

MORPH_CSV = Path("/content/morphemes.csv")
assert MORPH_CSV.exists(), f"Not found: {MORPH_CSV}"

df = pd.read_csv(MORPH_CSV, dtype=str, keep_default_na=False)
# normalize cols
df.columns = [c.strip().lower() for c in df.columns]

required = {"type","morpheme"}
missing = required - set(df.columns)
if missing:
    raise ValueError(f"CSV is missing columns: {missing}. Expected at least: {required} plus example1..exampleN")

# collect example cols
ex_cols = [c for c in df.columns if c.startswith("example")]
ex_cols.sort(key=lambda s: (len(s), s))
morph = {"prefixes": defaultdict(list), "suffixes": defaultdict(list)}
skipped = 0

for _, r in df.iterrows():
    kind = r["type"].strip().lower()
    piece = r["morpheme"].strip()
    if kind not in ("prefix","suffix") or not piece:
        skipped += 1
        continue
    examples = []
    for c in ex_cols:
        val = str(r[c]).strip()
        if val and val.lower() != "nan":
            examples.append(val)
    if kind == "prefix":
        morph["prefixes"][piece].extend(examples)
    else:
        morph["suffixes"][piece].extend(examples)

# dedupe + sort
for d in (morph["prefixes"], morph["suffixes"]):
    for k in list(d.keys()):
        d[k] = sorted(set([w for w in d[k] if w]))

prefixes = list(morph["prefixes"].keys())
suffixes = list(morph["suffixes"].keys())

print(f"[morph] prefixes: {len(prefixes)} | suffixes: {len(suffixes)}")
print(f"[morph] prefix examples total: {sum(len(v) for v in morph['prefixes'].values())} | "
      f"suffix examples total: {sum(len(v) for v in morph['suffixes'].values())} | skipped rows: {skipped}")

out_dir = (PERSIST/"runs"/RUN_ID)
out_dir.mkdir(parents=True, exist_ok=True)
(out_dir/"morpheme_data.json").write_text(json.dumps(morph, indent=2), encoding="utf-8")

# expose variables used downstream
print("Sample prefixes:", prefixes[:5])
print("Sample suffixes:", suffixes[:5])

[morph] prefixes: 15 | suffixes: 15
[morph] prefix examples total: 150 | suffix examples total: 150 | skipped rows: 0
Sample prefixes: ['all', 'be', 'con', 'de', 'en']
Sample suffixes: ['ure', 'y', 'ty', 'th', 'ster']


Tok expansion + composed init

In [None]:
import re, json, math, numpy as np, torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM

assert 'morph' in globals() and 'prefixes' in globals() and 'suffixes' in globals(), "Run the morpheme loader first."

# synthetic_lines
if 'synthetic_lines' not in globals() or not synthetic_lines:
    import random
    ROOTS = ["river","thunder","word","sound","dance","queen","storm","tree","night","sun","rain","book"]
    random.seed(13)
    synthetic_lines = []
    for _ in range(600):
        p = random.choice(prefixes) if prefixes else ""
        r = random.choice(ROOTS)
        s = random.choice(suffixes) if suffixes else ""
        synthetic_lines.append(f"The {p+r+s} rose and fell as if the {r} had learned to {s or 'sing'}.\n")
    (PERSIST/"runs"/RUN_ID).mkdir(parents=True, exist_ok=True)
    (PERSIST/"runs"/RUN_ID/"synthetic_lines.txt").write_text("".join(synthetic_lines), encoding="utf-8")

# tokenizer
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# synthetic + raw morph pieces
new_tokens = set()
for line in synthetic_lines:
    for m in re.finditer(r"\b([a-zA-Z][a-zA-Z\-]{2,})\b", line):
        new_tokens.add(m.group(1).lower())
for p in prefixes: new_tokens.add(p)
for s in suffixes: new_tokens.add(s)

# filter unknowns only
new_tokens = [t for t in sorted(new_tokens) if tok.convert_tokens_to_ids(t) == tok.unk_token_id]

TOK_SAVE = (RUN_DIR/"phase2_lora"/"final_adapters")
TOK_SAVE.mkdir(parents=True, exist_ok=True)
added = tok.add_tokens(new_tokens, special_tokens=False)
tok.save_pretrained(str(TOK_SAVE))
with open(METRICS_DIR/"new_tokens.json", "w") as fh:
    json.dump(new_tokens, fh)
print(f"[tokenizer] added {added} tokens → saved at {TOK_SAVE}")

# load model, resize without mean-resizing, tie head
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, dtype=torch.float32, device_map="auto")  # `torch_dtype` is deprecated in favour of `dtype`
old_size = model.get_input_embeddings().weight.shape[0]
model.resize_token_embeddings(len(tok), mean_resizing=False)
emb = model.get_input_embeddings().weight
device, dtype = emb.device, emb.dtype

# helper: only return real token embeddings
def emb_for_token(t: str):
    tid = tok.convert_tokens_to_ids(t)
    if tid >= 0 and tid != tok.unk_token_id:
        return emb[tid].detach().clone()
    return None

def mean_emb_for_words(words):
    vecs = []
    for w in words:
        tid = tok.convert_tokens_to_ids(w)
        if tid >= 0 and tid != tok.unk_token_id:
            vecs.append(emb[tid].detach().clone())
    if vecs:
        return torch.stack(vecs, dim=0).mean(dim=0)
    # fallback: small random around global std
    std = emb.detach().std().item()
    return torch.randn(emb.shape[1], device=device, dtype=dtype) * (0.1 * std)

alpha = 0.25
E_comp = []
new_ids = []

# prebuild sorted lists for greedy matches
_pref_sorted = sorted(prefixes, key=len, reverse=True)
_suf_sorted  = sorted(suffixes, key=len, reverse=True)

for t in new_tokens:
    tid = tok.convert_tokens_to_ids(t)
    if tid < old_size:
        continue

    # greedy longest prefix/suffix match
    p = next((pp for pp in _pref_sorted if t.startswith(pp)), "")
    s = next((ss for ss in _suf_sorted  if t.endswith(ss)), "")
    core = t[len(p):len(t)-len(s) if s else None]

    Ep = mean_emb_for_words(morph["prefixes"].get(p, [p])) if p else torch.zeros(emb.shape[1], device=device, dtype=dtype)
    tmp = emb_for_token(core)
    Er  = tmp if tmp is not None else mean_emb_for_words([core])
    Es = mean_emb_for_words(morph["suffixes"].get(s, [s])) if s else torch.zeros(emb.shape[1], device=device, dtype=dtype)

    comp = alpha*Ep + (1 - 2*alpha)*Er + alpha*Es
    comp = comp + 0.01 * emb.detach().std().item() * torch.randn_like(comp)

    with torch.no_grad():
        emb[tid] = comp
    new_ids.append(tid)
    E_comp.append(comp.detach().cpu().numpy())

new_ids = torch.tensor(new_ids, dtype=torch.long, device=device)
np.save(METRICS_DIR/"E_comp_newtokens.npy", np.stack(E_comp))
print(f"[init] wrote composed embeddings for {len(new_ids)} new rows. Vocab {old_size} → {len(tok)}")

# tie output head
model.lm_head.weight = model.get_input_embeddings().weight
model.config.pad_token_id = tok.pad_token_id
model.config.use_cache = False
model.config._attn_implementation = "eager"

[tokenizer] added 534 tokens → saved at /content/runs/t4_1762113417/phase2_lora/final_adapters


`torch_dtype` is deprecated! Use `dtype` instead!


[init] wrote composed embeddings for 534 new rows. Vocab 32000 → 32534
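
The greedy longest-match split used when composing embeddings can be sketched in isolation. This is a variation rather than a copy of the cell above: the hypothetical helper `split_affixes` matches the suffix against what remains after the prefix and requires a non-empty core, so the two affixes can never overlap. The example affix lists are illustrative.

```python
def split_affixes(token, prefixes, suffixes):
    """Greedy longest-match split into (prefix, core, suffix); empty strings when nothing matches."""
    pref_sorted = sorted(prefixes, key=len, reverse=True)
    suf_sorted = sorted(suffixes, key=len, reverse=True)
    p = next((pp for pp in pref_sorted if token.startswith(pp)), "")
    rest = token[len(p):]
    # Only accept a suffix that leaves at least one character of core.
    s = next((ss for ss in suf_sorted if rest.endswith(ss) and len(ss) < len(rest)), "")
    core = rest[:len(rest) - len(s)] if s else rest
    return p, core, s

print(split_affixes("presounder", ["pre", "p"], ["er", "r"]))  # ('pre', 'sound', 'er')
print(split_affixes("ensoundster", ["en"], ["ster", "y"]))     # ('en', 'sound', 'ster')
print(split_affixes("en", ["en"], ["ster", "y"]))              # ('en', '', '')
```

The last case shows a bare morpheme: the core is empty, so a caller would need a fallback vector for the root, much as the cell above falls back to a small random vector.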


build dataset (FW + Synthetic)

In [None]:
assert tok.pad_token_id is not None
assert model.lm_head.weight.data_ptr() == model.get_input_embeddings().weight.data_ptr()
print("new tokens:", len(new_ids), "| vocab size:", len(tok))

new tokens: 534 | vocab size: 32534


# P1 Embeddings-only warm-up

In [None]:
!pip -q uninstall -y accelerate || true
!pip -q install --no-cache-dir "accelerate==1.2.1"
import os; os.kill(os.getpid(), 9)  # kill the kernel so Colab restarts the runtime with the new accelerate


In [None]:
FW_TEXT = Path("/content/FW_TEXT.txt")
assert FW_TEXT.exists(), f"Not found: {FW_TEXT}"
fw_text = FW_TEXT.read_text(encoding="utf-8")
print(f"Loaded {len(fw_text)} chars from {FW_TEXT}")

Loaded 1364712 chars from /content/FW_TEXT.txt


In [None]:
import random
from datasets import Dataset

# FW blocks
MAX_LEN = 384  # tokens per block (below the 1024 ceiling listed under Run controls)
def chunk_text(txt, max_chars=1200):
    parts = []
    buf = []
    n = 0
    for line in txt.splitlines():
        if not line.strip(): continue
        buf.append(line)
        n += len(line)+1
        if n >= max_chars:
            parts.append("\n".join(buf)+"\n")
            buf, n = [], 0
    if buf: parts.append("\n".join(buf)+"\n")
    return parts

fw_blocks = chunk_text(fw_text) if fw_text else []
mix = fw_blocks + synthetic_lines
random.shuffle(mix)

def to_ids(s):
    return tok.encode(s, add_special_tokens=False)[:MAX_LEN]

# encode each block once and drop any that come back empty
enc = [{"input_ids": ids} for ids in map(to_ids, mix) if ids]
# small split
split = int(0.9*len(enc))
train_ds = Dataset.from_list(enc[:split])
valid_ds = Dataset.from_list(enc[split:]) if split < len(enc) else Dataset.from_list(enc[:100])

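def dc(features):
    # right-pad each batch to its longest sequence
    import torch
    maxlen = max(len(x["input_ids"]) for x in features)
    input_ids, attention_mask, labels = [], [], []
    for f in features:
        ids = f["input_ids"]
        n_pad = maxlen - len(ids)
        input_ids.append(ids + [tok.pad_token_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
        # -100 is ignored by the causal-LM loss, so padding does not contribute
        labels.append(ids + [-100] * n_pad)
    return {"input_ids": torch.tensor(input_ids),
            "attention_mask": torch.tensor(attention_mask),
            "labels": torch.tensor(labels)}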

print("train blocks:", len(train_ds), "valid:", len(valid_ds))

train blocks: 1538 valid: 171


In [None]:
import inspect, accelerate, transformers
from accelerate import Accelerator
print("Transformers:", transformers.__version__)
print("Accelerate  :", accelerate.__version__, accelerate.__file__)
print("unwrap_model sig:", inspect.signature(Accelerator.unwrap_model))

Transformers: 4.57.1
Accelerate  : 1.0.1 /usr/local/lib/python3.12/dist-packages/accelerate/__init__.py
unwrap_model sig: (self, model, keep_fp32_wrapper: 'bool' = True)


In [None]:
import inspect
from accelerate import Accelerator

_orig_unwrap = Accelerator.unwrap_model

def _unwrap_model_compat(self, model, *args, **kwargs):
    # Transformers>=4.56 may pass keep_torch_compile; older Accelerate doesn't accept it.
    kwargs.pop("keep_torch_compile", None)
    # Map/ensure keep_fp32_wrapper with sensible default
    keep_fp32_wrapper = kwargs.pop("keep_fp32_wrapper", True)
    # If caller passed it positionally, respect that
    if args:
        # first positional after model would be keep_fp32_wrapper
        keep_fp32_wrapper = bool(args[0])
    return _orig_unwrap(self, model, keep_fp32_wrapper)

Accelerator.unwrap_model = _unwrap_model_compat
print("Patched accelerate.Accelerator.unwrap_model for keep_torch_compile compatibility.")
print("new sig (logical): (self, model, keep_fp32_wrapper: bool = True)")

Patched accelerate.Accelerator.unwrap_model for keep_torch_compile compatibility.
new sig (logical): (self, model, keep_fp32_wrapper: bool = True)


In [None]:
from transformers import Trainer, TrainingArguments
import numpy as np, torch, json
from sklearn.metrics.pairwise import cosine_similarity
from pathlib import Path

# freeze all except embeddings + head
for p in model.parameters(): p.requires_grad = False
E = model.get_input_embeddings().weight; E.requires_grad_(True)
for p in model.lm_head.parameters(): p.requires_grad = True

# Phase 1 schedule: Adafactor, no warm-up, no intermediate checkpoints or evals
args1 = TrainingArguments(
    output_dir=str(RUN_DIR/"phase1"),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    max_steps=2000,
    learning_rate=5e-4,
    warmup_ratio=0.0,
    logging_strategy="steps", logging_steps=50,
    save_strategy="no",
    eval_strategy="no",
    gradient_checkpointing=True,
    fp16=False,
    report_to="none",
    optim="adafactor",          # Adafactor keeps optimizer state small, which helps on a T4
)

trainer1 = Trainer(model=model, args=args1, train_dataset=train_ds, data_collator=dc)
out1 = trainer1.train()
print(out1)

# snapshot vs composed init
ids_np = new_ids.cpu().numpy()
rows = np.arange(len(ids_np))
with torch.no_grad():
    W1 = model.get_input_embeddings().weight.detach().cpu().numpy()

# post-P1 neighbors of each new row; mask the row itself instead of assuming it ranks first
sim1 = cosine_similarity(W1[ids_np], W1)
sim1[rows, ids_np] = -np.inf
top5_1 = np.argsort(-sim1, axis=1)[:, :5]

# neighbors of the composed-init vectors, searched in the same post-P1 vocabulary
E_comp_np = np.load(METRICS_DIR/"E_comp_newtokens.npy")
sim0 = cosine_similarity(E_comp_np, W1)
sim0[rows, ids_np] = -np.inf
top5_0 = np.argsort(-sim0, axis=1)[:, :5]

def overlap5(a,b): return len(set(a.tolist()) & set(b.tolist()))
overlaps1 = np.array([overlap5(top5_0[i], top5_1[i]) for i in range(len(new_ids))])
# norms of the new rows after Phase 1 (written to pre_norms.npy) and their change from the composed init
pre_norms = np.linalg.norm(W1[ids_np], axis=1)
norm_delta = pre_norms - np.linalg.norm(E_comp_np, axis=1)

(Path(METRICS_DIR/"summary_stats_p1.json")).write_text(
    json.dumps({"phase":"phase1","compared_tokens":int(len(new_ids)),
                "mean_top5_overlap":float(overlaps1.mean()),
                "mean_norm_delta":float(norm_delta.mean())}, indent=2)
)
(Path(METRICS_DIR/"morpheme_comparison_p1.json")).write_text(
    json.dumps({"top5_pre":top5_0.tolist(),"top5_p1":top5_1.tolist(),
                "overlap@5":overlaps1.tolist()}, indent=2)
)
np.save(METRICS_DIR/"pre_norms.npy", pre_norms)
print("P1 mean overlap@5:", overlaps1.mean())

The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss
50,6.3321
100,5.3361
150,4.3737
200,4.2891
250,3.4766
300,3.4358
350,2.8894
400,2.8104
450,2.4056
500,2.3953


In [None]:
import os, json, time, pathlib, numpy as np, torch
from datasets import load_from_disk
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer,
    DataCollatorForLanguageModeling, TrainerCallback
)

# set up
DRIVE_ROOT = pathlib.Path("/content/drive/MyDrive/wake2vec")
RUN_ID = f"t4_{int(time.time())}"
(DRIVE_ROOT/"runs"/RUN_ID).mkdir(parents=True, exist_ok=True)

RUN_DIR = pathlib.Path("/content/runs")/RUN_ID
if not RUN_DIR.exists():
    RUN_DIR.parent.mkdir(parents=True, exist_ok=True)
    os.symlink(str(DRIVE_ROOT/"runs"/RUN_ID), str(RUN_DIR))

METRICS_DIR = RUN_DIR/"metrics"; METRICS_DIR.mkdir(parents=True, exist_ok=True)

# callbacks
class LossStreamer(TrainerCallback):
    def __init__(self, log_every=50, window=200, out_json=None):
        self.log_every, self.window, self.out_json = log_every, window, out_json
        self.buf, self.recs = [], []
    def on_log(self, args, state, control, logs=None, **kw):
        if not logs or "loss" not in logs: return
        s = int(state.global_step or 0); L = float(logs["loss"])
        self.buf.append(L); self.recs.append({"step": s, "loss": L, "lr": logs.get("learning_rate")})
        if s % self.log_every == 0 and s > 0:
            w = self.buf[-self.window:]; ma = sum(w)/len(w)
            print(f"[P1 {s}] train_loss={L:.4f}  ma({len(w)})={ma:.4f}")
            if self.out_json:
                with open(self.out_json, "w") as fh:
                    fh.write(json.dumps(self.recs, indent=2))

loss_cb = LossStreamer(out_json=str(METRICS_DIR/"phase1_loss_log.json"))

# Snapshot new rows every 200 steps
NEW_IDS_PATH = DRIVE_ROOT/"new_ids.npy"
new_ids = np.load(NEW_IDS_PATH) if NEW_IDS_PATH.exists() else None

class EmbedSnapshot(TrainerCallback):
    def __init__(self, run_dir, new_ids, every=200):
        self.run_dir, self.new_ids, self.every = pathlib.Path(run_dir), new_ids, every
    def on_step_end(self, args, state, control, **kw):
        if self.new_ids is None: return
        s = int(state.global_step or 0)
        if s>0 and s % self.every == 0:
            m = kw.get("model")
            if m is None: return
            with torch.no_grad():
                E = m.get_input_embeddings().weight.detach().cpu().numpy()
            np.save(self.run_dir/"metrics"/f"E_postP1_step{s}.npy", E[self.new_ids])
            print(f"[SNAP] new-row embeddings @ {s}")

snap_cb = EmbedSnapshot(RUN_DIR, new_ids, every=200)

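# model / tokenizer
# NOTE: this reloads the base checkpoint, so the expanded tokenizer and the composed
# embedding rows built in the earlier cell are not carried into this run.
BASE = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
tok = AutoTokenizer.from_pretrained(BASE, use_fast=True)
if tok.pad_token_id is None: tok.pad_token = tok.eos_token or "</s>"

model = AutoModelForCausalLM.from_pretrained(BASE, dtype=torch.float32, device_map="auto")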
model.config.use_cache = False
with torch.no_grad():
    model.get_output_embeddings().weight = model.get_input_embeddings().weight

train_ds = load_from_disk(str(DRIVE_ROOT/"datasets"/"train_ds"))
valid_ds = load_from_disk(str(DRIVE_ROOT/"datasets"/"valid_ds"))

# use a tiny shard for quick evals
valid_ds_small = valid_ds.select(range(min(1000, len(valid_ds))))
collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)

# training args
args = TrainingArguments(
    output_dir=str(RUN_DIR),
    seed=42,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    max_steps=1100,
    learning_rate=5e-4,
    warmup_ratio=0.0,
    optim="adafactor",
    logging_steps=50,
    save_steps=100,
    save_total_limit=12,
    eval_strategy="steps",      # same argument name as the Phase 1 cell above; `evaluation_strategy` is the older spelling
    eval_steps=200,
    gradient_checkpointing=True,
    fp16=False, bf16=False,
    report_to=["none"],
    max_grad_norm=1.0,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=valid_ds_small,
    data_collator=collator,
    callbacks=[loss_cb, snap_cb],
)

open(RUN_DIR/"run_manifest.json","w").write(json.dumps({
    "run_id": RUN_ID, "base": BASE, "started": time.time(),
    "max_steps": 1100, "grad_accum": 16, "optim": "adafactor",
    "eval_strategy": "steps", "eval_steps": 200, "valid_shard": len(valid_ds_small)
}, indent=2))

trainer.train()
trainer.save_model(str(RUN_DIR/"checkpoint-final"))
tok.save_pretrained(str(RUN_DIR/"checkpoint-final"))
print("[DONE] P1 complete →", RUN_DIR/"checkpoint-final")

In [None]:
import os, glob, json, pathlib, time, re
from google.colab import drive

# Mount Drive (idempotent)
try:
    drive.mount('/content/drive')
except Exception:
    pass

from pathlib import Path
PERSIST = Path("/content/drive/MyDrive/wake2vec")
PERSIST.mkdir(parents=True, exist_ok=True)

# Prefer last local run, else last Drive run
local_runs = sorted(Path("/content/runs").glob("*"), key=lambda p: p.stat().st_mtime) if Path("/content/runs").exists() else []
drive_runs = sorted((PERSIST / "runs").glob("*"), key=lambda p: p.stat().st_mtime) if (PERSIST / "runs").exists() else []

if local_runs:
    RUN_DIR = local_runs[-1]
    RUN_ID = RUN_DIR.name
    SOURCE = "local"
else:
    if drive_runs:
        RUN_DIR = drive_runs[-1]
        RUN_ID = RUN_DIR.name
        SOURCE = "drive"
    else:
        raise SystemExit("No prior run dirs found in /content/runs or Drive /wake2vec/runs.")

print(f"[INFO] Using RUN_ID={RUN_ID} from {SOURCE}: {RUN_DIR}")

(PERSIST / "runs" / RUN_ID).mkdir(parents=True, exist_ok=True)
if SOURCE == "local":
    os.system(f'rsync -a --ignore-existing "{RUN_DIR}/" "{PERSIST}/runs/{RUN_ID}/"')
else:
    # create local symlink for fast IO but persistent storage
    Path("/content/runs").mkdir(parents=True, exist_ok=True)
    target = Path("/content/runs") / RUN_ID
    if not target.exists():
        os.symlink(str(RUN_DIR), str(target))
    RUN_DIR = target
    print(f"[INFO] Symlinked drive run → {RUN_DIR}")

# Create common dirs
METRICS_DIR = RUN_DIR / "metrics"
METRICS_DIR.mkdir(parents=True, exist_ok=True)

# latest checkpoint dir
def _ckpt_step(path):
    # numeric step from "checkpoint-<step>"; names without a number sort first
    m = re.search(r"checkpoint-(\d+)", path)
    return int(m.group(1)) if m else -1

ckpts = sorted(glob.glob(str(RUN_DIR / "checkpoint-*")), key=_ckpt_step)
CKPT = ckpts[-1] if ckpts else None
print(f"[INFO] Latest checkpoint: {CKPT if CKPT else 'NONE'}")

manifest = {
    "run_id": RUN_ID,
    "resumed_at": time.time(),
    "source": SOURCE,
    "latest_ckpt": CKPT,
}
(RUN_DIR / "resume_manifest.json").write_text(json.dumps(manifest, indent=2))
print("[OK] resume_manifest.json written.")

Mounted at /content/drive
[INFO] Using RUN_ID=t4_1762113417 from drive: /content/drive/MyDrive/wake2vec/runs/t4_1762113417
[INFO] Symlinked drive run → /content/runs/t4_1762113417
[INFO] Latest checkpoint: NONE
[OK] resume_manifest.json written.


In [None]:
# P1 FINALIZE: compute overlap@5 (composed-init vs post-P1), norm stats, plots
import os, json, math, glob, pathlib, time, shutil
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer

# run meta
ROOT = pathlib.Path("/content")
RUNS = ROOT / "runs"
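run_names = sorted(p.name for p in RUNS.glob("*"))
if not run_names:
    # an empty list here would otherwise surface as a bare IndexError
    raise SystemExit(f"No run directories under {RUNS}; run or resume Phase 1 first.")
RUN_ID = run_names[-1]  # last run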
RUN_DIR = RUNS / RUN_ID
METRICS_DIR = RUN_DIR / "metrics"
PLOTS_DIR = RUN_DIR / "plots"
PERSIST = pathlib.Path("/content/drive/MyDrive/wake2vec")
ADAPT_DIR = RUN_DIR / "phase2_lora" / "final_adapters"
for d in (METRICS_DIR, PLOTS_DIR, ADAPT_DIR, PERSIST): d.mkdir(parents=True, exist_ok=True)

manifest_path = RUN_DIR / "run_manifest.json"
if manifest_path.exists():
    manifest = json.loads(manifest_path.read_text())
else:
    manifest = {"run_id": RUN_ID, "time": time.time()}
    manifest_path.write_text(json.dumps(manifest, indent=2))

# Locate the latest numbered checkpoint (the trainer writes 'checkpoint-<step>'); fall back to the run dir.
# Non-numeric names such as 'checkpoint-final' are skipped so that int() cannot fail.
ckpts = [p for p in glob.glob(str(RUN_DIR / "checkpoint-*")) if p.rsplit("-", 1)[-1].isdigit()]
ckpts.sort(key=lambda p: int(p.rsplit("-", 1)[-1]))
ckpt_dir = ckpts[-1] if ckpts else str(RUN_DIR)
base_model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"

print(f"[INFO] Loading model/tokenizer from: {ckpt_dir}")
tok = AutoTokenizer.from_pretrained(ckpt_dir, use_fast=True)
if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token or "</s>"
model = AutoModelForCausalLM.from_pretrained(ckpt_dir, dtype=torch.float32, device_map="auto")
model.config.use_cache = False
with torch.no_grad():
    model.get_output_embeddings().weight = model.get_input_embeddings().weight

emb = model.get_input_embeddings().weight.detach().cpu().numpy()  # [V, d]
V, d = emb.shape
print(f"[INFO] vocab={V}, dim={d}")

# Load new_ids and composed init vectors (optional)
new_ids_path = PERSIST / "new_ids.npy"
E_comp_path = PERSIST / "E_comp.npy"            # composed init vectors aligned to new_ids

new_ids = np.load(new_ids_path) if new_ids_path.exists() else None
E_comp  = np.load(E_comp_path)  if E_comp_path.exists()  else None
if new_ids is not None:
    print(f"[INFO] new_ids loaded: {new_ids.shape[0]}")
else:
    print("[WARN] new_ids.npy not found; will compute only global stats/plots.")

# top-k neighbors by cosine
def topk_neighbors(vecs, mat, k=5, mask_self=None):
    # vecs: [m, d], mat: [V, d]
    # returns indices [m, k]
    va = vecs / (np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-9)
    ma = mat / (np.linalg.norm(mat, axis=1, keepdims=True) + 1e-9)
    sims = va @ ma.T  # [m, V]
    if mask_self is not None:
        sims[np.arange(sims.shape[0])[:,None], mask_self[:,None]] = -1e9
    nbrs = np.argpartition(-sims, kth=np.arange(k), axis=1)[:, :k]  # use k, not a hard-coded 5
    # sort each row’s top-k
    part_vals = np.take_along_axis(sims, nbrs, axis=1)
    order = np.argsort(-part_vals, axis=1)
    return np.take_along_axis(nbrs, order, axis=1)

# Overlap@5: composed-init vs post-P1
overlap5 = None
if (new_ids is not None) and (E_comp is not None) and (E_comp.shape[0] == new_ids.shape[0]):
    # neighbors of composed vectors vs neighbors of current learned new embeddings
    E_new_post = emb[new_ids]                         # [n_new, d] current rows after P1
    k = 5
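    # mask each token's own row so it cannot count as its own neighbor
    nbr_comp = topk_neighbors(E_comp, emb, k=k, mask_self=new_ids)      # composed vec, top-k in current vocab
    nbr_post = topk_neighbors(E_new_post, emb, k=k, mask_self=new_ids)  # learned row, top-k in current vocab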

    # overlap fraction per token
    ov = []
    for a, b in zip(nbr_comp, nbr_post):
        ov.append(len(set(a.tolist()) & set(b.tolist())) / k)
    overlap5 = np.array(ov)
    np.save(METRICS_DIR / "p1_overlap_at5.npy", overlap5)
    print(f"[OK] P1 overlap@5: mean={overlap5.mean():.4f} | median={np.median(overlap5):.4f}")
else:
    print("[WARN] Skipping overlap@5 (missing E_comp or new_ids).")

# Norm stats
norms = np.linalg.norm(emb, axis=1)
np.save(METRICS_DIR / "postP1_norms.npy", norms)
if new_ids is not None:
    norms_new = norms[new_ids]
    np.save(METRICS_DIR / "postP1_norms_new.npy", norms_new)

# Plots
plt.figure(figsize=(7,4.5))
plt.hist(norms, bins=60)
plt.title("Post-P1 embedding norms (all vocab)")
plt.xlabel("‖E‖"); plt.ylabel("count")
plt.tight_layout()
plt.savefig(PLOTS_DIR / "hist_postP1_norms.png", dpi=160)

if (overlap5 is not None):
    plt.figure(figsize=(7,4.5))
    plt.hist(overlap5, bins=np.linspace(0,1,11), align="left", rwidth=0.9)
    plt.xticks(np.linspace(0,1,11))
    plt.title("P1 overlap@5 (composed-init vs post-P1)")
    plt.xlabel("overlap@5"); plt.ylabel("count")
    plt.tight_layout()
    plt.savefig(PLOTS_DIR / "hist_overlap_top5_P1.png", dpi=160)

# Persist quick stats
stats = {
    "run_id": RUN_ID,
    "vocab": int(V),
    "dim": int(d),
    "n_new": int(new_ids.shape[0]) if new_ids is not None else None,
    "p1_overlap5_mean": float(overlap5.mean()) if overlap5 is not None else None,
    "p1_overlap5_median": float(np.median(overlap5)) if overlap5 is not None else None,
    "timestamp": time.time(),
}
with open(METRICS_DIR / "p1_stats.json", "w") as f:  # context manager closes the file handle
    json.dump(stats, f, indent=2)
print("[DONE] P1 metrics saved →", METRICS_DIR)
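
The overlap@5 diagnostic above can be checked on toy data before trusting it on the real embedding table. The sketch below is a minimal NumPy illustration, not the notebook's own code: `cosine_topk` and `overlap_at_k` are hypothetical helper names, the "vocabulary" is random, and a full `argsort` replaces `argpartition` for clarity. It shows the two extremes: identical vectors give overlap 1.0, and sign-flipped vectors give overlap 0.0.

```python
import numpy as np

def cosine_topk(queries, table, k):
    """Indices of the k rows of `table` most cosine-similar to each query, best first."""
    q = queries / (np.linalg.norm(queries, axis=1, keepdims=True) + 1e-9)
    t = table / (np.linalg.norm(table, axis=1, keepdims=True) + 1e-9)
    sims = q @ t.T                       # [m, V] cosine similarities
    return np.argsort(-sims, axis=1)[:, :k]

def overlap_at_k(nbrs_a, nbrs_b):
    """Per-row fraction of neighbor ids shared between two [m, k] index arrays."""
    k = nbrs_a.shape[1]
    return np.array([len(set(a.tolist()) & set(b.tolist())) / k
                     for a, b in zip(nbrs_a, nbrs_b)])

# Toy data (assumption): 50 random "vocabulary" rows and 3 query vectors.
rng = np.random.default_rng(0)
table = rng.normal(size=(50, 16))
init = rng.normal(size=(3, 16))          # stands in for composed init vectors
flipped = -init                          # stands in for vectors that drifted badly

same = overlap_at_k(cosine_topk(init, table, 5), cosine_topk(init, table, 5))
opposite = overlap_at_k(cosine_topk(init, table, 5), cosine_topk(flipped, table, 5))
print("identical:", same, "| sign-flipped:", opposite)
```

Real runs should land between these extremes; a mean near 0 would suggest the learned rows left their initial neighborhoods entirely.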


In [None]:
import shutil, tarfile

def snapshot_to_drive(tag):
    # copy artifacts
    d_run = PERSIST/'runs'/RUN_ID/tag
    d_run.mkdir(parents=True, exist_ok=True)
    if METRICS_DIR.exists(): shutil.copytree(METRICS_DIR, d_run/'metrics', dirs_exist_ok=True)
    if PLOTS_DIR.exists():   shutil.copytree(PLOTS_DIR,   d_run/'plots',   dirs_exist_ok=True)
    if (REPORTS_DIR/'Wake2Vec_Report.html').exists():
        (PERSIST/'reports').mkdir(parents=True, exist_ok=True)
        shutil.copy(REPORTS_DIR/'Wake2Vec_Report.html', PERSIST/'reports'/f'Wake2Vec_Report_{RUN_ID}_{tag}.html')
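    # adapters/tokenizer: mirror the local Phase-2 adapters to Drive
    # (the source was previously the Drive folder itself, so the copy targeted its own path)
    src_ad = RUN_DIR/'phase2_lora'/'final_adapters'
    if src_ad.exists():
        shutil.copytree(src_ad, PERSIST/'adapters'/RUN_ID/'final_adapters', dirs_exist_ok=True)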
    # tarball of the local run folder for belt+braces
    (PERSIST/'archives').mkdir(parents=True, exist_ok=True)
    with tarfile.open(PERSIST/'archives'/f"{RUN_ID}_{tag}.tar.gz", "w:gz") as tar:
        tar.add(str(RUN_DIR), arcname=f"runs/{RUN_ID}")
    print(f"[snapshot] saved → Drive under runs/{RUN_ID}/{tag} and archives/")

In [None]:
snapshot_to_drive("phase1")
snapshot_to_drive("phase2")
snapshot_to_drive("phase3")

# P2 LoRA boost

In [None]:
# --- Phase 2 (LoRA): LossStreamer callback plus safe fallbacks ---
import os, json, pathlib, time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling, TrainerCallback
from datasets import load_from_disk

RUNS = pathlib.Path("/content/runs")
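run_names = sorted(p.name for p in RUNS.glob("*") if p.is_dir())
if not run_names:
    # an empty /content/runs would otherwise raise a bare IndexError on [-1]
    raise FileNotFoundError(f"No run folders found under {RUNS}; run Phase 1 first.")
RUN_ID = run_names[-1]  # lexicographically last run folder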
RUN_DIR = RUNS / RUN_ID
METRICS_DIR = RUN_DIR / "metrics"
ADAPT_DIR = RUN_DIR / "phase2_lora" / "final_adapters"
PERSIST = pathlib.Path("/content/drive/MyDrive/wake2vec")
for _dir in (METRICS_DIR, ADAPT_DIR, PERSIST): _dir.mkdir(parents=True, exist_ok=True)  # avoid shadowing d (embedding dim)

# -- Loss streamer callback --
class LossStreamer(TrainerCallback):
    def __init__(self, log_every=50, window=200, out_json=None):
        self.log_every = log_every
        self.window = window
        self.buf = []
        self.recs = []
        self.out_json = out_json

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs or "loss" not in logs:
            return
        step = int(state.global_step or 0)
        loss = float(logs["loss"])
        self.buf.append(loss)
        self.recs.append({"step": step, "loss": loss, "lr": logs.get("learning_rate", None)})
        if step % self.log_every == 0 and step > 0:
            w = self.buf[-self.window:]
            ma = sum(w)/len(w)
            lr = logs.get("learning_rate", None)
            print(f"[P2 step {step}] loss={loss:.4f}  ma({len(w)})={ma:.4f}" + (f"  lr={lr:.2e}" if lr else ""))
            if self.out_json:
                with open(self.out_json, "w") as f: json.dump(self.recs, f, indent=2)

loss_cb2 = LossStreamer(log_every=50, window=200, out_json=str(METRICS_DIR/"phase2_loss_log.json"))

# -- Safe imports for peft/bnb --
def try_import_peft():
    try:
        import peft
        return peft
    except Exception:
        return None

peft = try_import_peft()
if peft is None:
    print("[WARN] peft not available — will train base model (no adapters) and use Adafactor.")

# -- Load latest checkpoint as P2 start --
ckpts = sorted([p for p in RUN_DIR.glob("checkpoint-*")], key=lambda p: int(p.name.split("-")[-1]))
ckpt_dir = str(ckpts[-1] if ckpts else RUN_DIR)
print(f"[INFO] P2 loading from: {ckpt_dir}")

tok = AutoTokenizer.from_pretrained(ckpt_dir, use_fast=True)
if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token or "</s>"

model = AutoModelForCausalLM.from_pretrained(ckpt_dir, torch_dtype=torch.float32, device_map="auto")
model.config.use_cache = False
with torch.no_grad():
    model.get_output_embeddings().weight = model.get_input_embeddings().weight  # tie

# -- Datasets (expects saved by P1) --
train_path = RUN_DIR / "train_ds"
valid_path = RUN_DIR / "valid_ds"
train_ds = load_from_disk(str(train_path)) if train_path.exists() else None
valid_ds = load_from_disk(str(valid_path)) if valid_path.exists() else None
assert (train_ds is not None) and (valid_ds is not None), "Tokenized datasets not found in RUN_DIR."

dc = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)

# -- Attach LoRA if available --
use_optim = "adamw_bnb_8bit"
if peft is not None:
    from peft import LoraConfig, get_peft_model, TaskType
    lcfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8, lora_alpha=16, lora_dropout=0.05,
        target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
        bias="none"
    )
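    # With gradient checkpointing and a frozen base model, the embedding output must
    # require grad, otherwise the checkpointed blocks may receive no gradient.
    model.enable_input_require_grads()
    model = get_peft_model(model, lcfg)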
    try:
        import bitsandbytes as bnb  # noqa: F401
    except Exception:
        print("[WARN] bitsandbytes missing — falling back to Adafactor.")
        use_optim = "adafactor"
else:
    use_optim = "adafactor"

# -- TrainingArguments (fixed arg names; T4-safe) --
args2 = TrainingArguments(
    output_dir=str(RUN_DIR/"phase2_lora"),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    num_train_epochs=1,
    learning_rate=2e-5,
    warmup_ratio=0.10,
    logging_strategy="steps",
    logging_steps=50,
    save_strategy="epoch",
    evaluation_strategy="epoch",  # light once-per-epoch eval; newer transformers releases name this eval_strategy
    gradient_checkpointing=True,
    fp16=False, bf16=False,
    report_to=["none"],
    optim=use_optim,
    max_grad_norm=1.0,
)

trainer2 = Trainer(
    model=model,
    args=args2,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    data_collator=dc,
    callbacks=[loss_cb2]
)

out2 = trainer2.train()
print(out2)

# -- Save adapters (or full model) + tokenizer, mirror to Drive --
if peft is not None and hasattr(model, "save_pretrained"):
    model.save_pretrained(str(ADAPT_DIR), safe_serialization=True)
else:
    # If no peft, persist full weights so you still have a P2 artifact
    (ADAPT_DIR / "full_model").mkdir(parents=True, exist_ok=True)
    model.save_pretrained(str(ADAPT_DIR / "full_model"), safe_serialization=True)

tok.save_pretrained(str(ADAPT_DIR))

import shutil
drive_target = PERSIST / "adapters" / RUN_ID / "final_adapters"
# copytree creates the target and avoids the shell quoting/globbing pitfalls of `cp -r`
shutil.copytree(ADAPT_DIR, drive_target, dirs_exist_ok=True)
print("[DONE] Phase-2 saved →", ADAPT_DIR, "and mirrored →", drive_target)
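
`LossStreamer` smooths the logged loss with a trailing window, but its buffer grows for the whole run. As a minimal sketch of the same smoothing with bounded memory, the hypothetical `RollingMean` helper below (not part of the notebook) keeps only the last `window` values in a `deque`. The numbers are made up for illustration.

```python
from collections import deque

class RollingMean:
    """Trailing-window mean; keeps at most `window` values in memory."""
    def __init__(self, window):
        self.buf = deque(maxlen=window)   # old values fall off automatically

    def update(self, value):
        self.buf.append(float(value))
        return sum(self.buf) / len(self.buf)

rm = RollingMean(window=3)
means = [rm.update(x) for x in [4.0, 2.0, 3.0, 1.0]]
print(means)  # [4.0, 3.0, 3.0, 2.0]
```

Until the window fills, the mean covers only the values seen so far, which matches what `ma({len(w)})` reports in the callback's log line.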

# P3 embedding alignment (new rows only, LM CE + anchors)


In [None]:
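import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics.pairwise import cosine_similarity

# freeze all but embeddings+head
# (done first: E must exist before any tensor below is placed on E.device)
for p in model.parameters(): p.requires_grad = False
E = model.get_input_embeddings().weight; E.requires_grad_(True)
for p in model.lm_head.parameters(): p.requires_grad = True

# new_ids is a NumPy array in the P1 metrics cell; the code below indexes it as a tensor
new_ids = torch.as_tensor(new_ids, dtype=torch.long, device=E.device)
new_ids_np = new_ids.cpu().numpy()

with torch.no_grad():
    W2 = model.get_input_embeddings().weight.detach().cpu().numpy()
    sim_pre = cosine_similarity(W2[new_ids_np], W2)
    top5_pre = np.argsort(-sim_pre, axis=1)[:,1:6]  # skips rank 0 (assumed to be the token itself); acts as P2 baseline

# targets (the saved arrays are assumed to be aligned row-for-row with new_ids)
centroids = torch.tensor(W2[top5_pre].mean(axis=1), dtype=torch.float32, device=E.device)
pre_norms = torch.tensor(np.load(METRICS_DIR/"pre_norms.npy")[:len(new_ids)], dtype=torch.float32, device=E.device)
E_comp = torch.tensor(np.load(METRICS_DIR/"E_comp_newtokens.npy")[:len(new_ids)], dtype=torch.float32, device=E.device)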

LMB_ANCHOR, LMB_CENTROID, LMB_NORM = 1e-3, 1e-3, 5e-4
id_to_row = {int(t.item()): r for r,t in enumerate(new_ids.cpu())}

def batch_rows(input_ids):
    ids = torch.unique(input_ids).tolist()
    rows = [id_to_row[i] for i in ids if i in id_to_row]
    return torch.tensor(rows, dtype=torch.long, device=E.device) if rows else None

from transformers import Trainer
class MorphAlignTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        outputs = model(**inputs); loss = outputs.loss
        rows = batch_rows(inputs["input_ids"])
        if rows is not None:
            E_rows = E[new_ids[rows].to(E.device)]
            l_cent = (1 - nn.functional.cosine_similarity(E_rows, centroids[rows], dim=1)).mean()
            l_norm = nn.functional.mse_loss(E_rows.norm(dim=1), pre_norms[rows])
            l_anch = nn.functional.mse_loss(E_rows, E_comp[rows])
            loss = loss + LMB_CENTROID*l_cent + LMB_NORM*l_norm + LMB_ANCHOR*l_anch
        return (loss, outputs) if return_outputs else loss

class NewRowMaskTrainer(MorphAlignTrainer):
    def training_step(self, model, inputs, num_items_in_batch=None):
        out = super().training_step(model, inputs, num_items_in_batch)
        if E.grad is not None:
            mask = torch.zeros_like(E.grad, dtype=torch.bool)
            mask[new_ids.to(E.device)] = True
            E.grad = torch.where(mask, E.grad, torch.zeros_like(E.grad))
        return out

from transformers import TrainerCallback
import torch, math, json, numpy as np

class P3LiveLogger(TrainerCallback):
    def __init__(self, new_ids, E, centroids, pre_norms, E_comp, log_every=50, out_json=METRICS_DIR/"phase3_live_log.json"):
        self.new_ids = new_ids
        self.E = E
        self.centroids = centroids
        self.pre_norms = pre_norms
        self.E_comp = E_comp
        self.log_every = log_every
        self.out_json = out_json
        self.records = []

    @torch.no_grad()
    def _metrics_snapshot(self, step):
        rows = self.new_ids.to(self.E.device)
        E_rows = self.E[rows]
        # terms
        cos_cent = torch.nn.functional.cosine_similarity(E_rows, self.centroids, dim=1)
        l_centroid = (1 - cos_cent).mean().item()
        l_norm = torch.nn.functional.mse_loss(E_rows.norm(dim=1), self.pre_norms).item()
        l_anchor = torch.nn.functional.mse_loss(E_rows, self.E_comp).item()
        # health stats
        mean_norm = E_rows.norm(dim=1).mean().item()
        grad_norm = (self.E.grad[rows].norm().item() if self.E.grad is not None else float("nan"))
        # light overlap probe: compare to centroids’ neighbors proxy in current space
        # (cheap proxy: average cosine with centroids instead of full top-k)
        mean_cos_cent = cos_cent.mean().item()
        self.records.append({
            "step": int(step),
            "l_centroid": l_centroid,
            "l_norm": l_norm,
            "l_anchor": l_anchor,
            "mean_norm": mean_norm,
            "grad_norm": grad_norm,
            "mean_cos_centroid": mean_cos_cent,
        })

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % self.log_every == 0 and state.global_step > 0:
            self._metrics_snapshot(state.global_step)
        return control

    def on_train_end(self, args, state, control, **kwargs):
        # final snapshot
        self._metrics_snapshot(state.global_step)
        with open(self.out_json, "w") as f:
            json.dump(self.records, f, indent=2)

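# Attach the live logger; without it phase3_live_log.json is never written and the plot cell fails.
# args3 (Phase-3 TrainingArguments) is assumed to be defined in an earlier cell.
p3_logger = P3LiveLogger(new_ids, E, centroids, pre_norms, E_comp, log_every=50)
trainer3 = NewRowMaskTrainer(model=model, args=args3, train_dataset=train_ds, data_collator=dc, callbacks=[p3_logger])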
print("Phase 3 — go")
out3 = trainer3.train(); print(out3)

# P3 snapshot
with torch.no_grad():
    W3 = model.get_input_embeddings().weight.detach().cpu().numpy()
    sim3 = cosine_similarity(W3[new_ids.cpu()], W3)
    top5_3 = np.argsort(-sim3, axis=1)[:,1:6]

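def overlap_count(a, b):
    # number of shared neighbor ids (0..5), not a fraction; renamed so it does not
    # shadow the overlap5 array from the P1 metrics cell
    return len(set(a.tolist()) & set(b.tolist()))
overlaps3 = np.array([overlap_count(top5_pre[i], top5_3[i]) for i in range(len(new_ids))])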
norms3 = np.linalg.norm(W3[new_ids.cpu()], axis=1)
norm_deltas3 = norms3 - pre_norms.detach().cpu().numpy()

summary_p3 = {"phase":"phase3", "compared_tokens": int(len(new_ids)),
              "mean_top5_overlap": float(overlaps3.mean()),
              "mean_norm_delta": float(norm_deltas3.mean())}
# METRICS_DIR is already a pathlib.Path, so no Path(...) wrapper (or import) is needed
(METRICS_DIR/"summary_stats_p3.json").write_text(json.dumps(summary_p3, indent=2))
(METRICS_DIR/"morpheme_comparison_p3.json").write_text(json.dumps({
    "top5_pre": top5_pre.tolist(), "top5_p3": top5_3.tolist(),
    "overlap@5": overlaps3.tolist(), "norm_deltas": norm_deltas3.tolist()
}, indent=2))
print("P3:", summary_p3)
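
The three auxiliary terms in `MorphAlignTrainer.compute_loss` and the row mask in `NewRowMaskTrainer` are easier to reason about on numbers small enough to check by hand. The sketch below restates them in NumPy so it runs without a GPU. `alignment_penalty` and `keep_only_rows` are hypothetical names, the default weights copy `LMB_CENTROID`, `LMB_NORM`, and `LMB_ANCHOR` from the cell above, and the cosine uses a plain epsilon rather than PyTorch's exact clamping, so tiny numerical differences are expected.

```python
import numpy as np

def alignment_penalty(rows, centroids, ref_norms, anchors,
                      lam_centroid=1e-3, lam_norm=5e-4, lam_anchor=1e-3):
    """Centroid (1 - cosine), norm (MSE) and anchor (MSE) terms plus their weighted sum."""
    row_norms = np.linalg.norm(rows, axis=1)
    cos = np.sum(rows * centroids, axis=1) / (row_norms * np.linalg.norm(centroids, axis=1) + 1e-8)
    l_cent = float(np.mean(1.0 - cos))                  # pull toward pre-training neighbor centroid
    l_norm = float(np.mean((row_norms - ref_norms) ** 2))  # keep norms near their reference values
    l_anch = float(np.mean((rows - anchors) ** 2))      # stay near the composed init vector
    total = lam_centroid * l_cent + lam_norm * l_norm + lam_anchor * l_anch
    return {"centroid": l_cent, "norm": l_norm, "anchor": l_anch, "total": total}

def keep_only_rows(grad, row_ids):
    """Zero every gradient row except those in row_ids (the newly added tokens)."""
    masked = np.zeros_like(grad)
    masked[row_ids] = grad[row_ids]
    return masked

# Hand-checkable toy values (assumptions, not data from any run).
rows      = np.array([[3.0, 4.0], [0.0, 2.0]])
centroids = np.array([[4.0, -3.0], [0.0, 2.0]])   # orthogonal to row 0, identical to row 1
ref_norms = np.array([4.0, 2.0])                   # row norms are 5 and 2
anchors   = np.array([[3.0, 3.0], [0.0, 2.0]])     # one coordinate off by 1

terms = alignment_penalty(rows, centroids, ref_norms, anchors)
masked = keep_only_rows(np.ones((4, 2)), np.array([1, 3]))
print(terms, masked.sum())
```

With these weights, the combined penalty on the toy data is about 1e-3, which suggests the regularizers act as a gentle tether next to a language-modeling loss of order 1 rather than as a dominant objective.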

In [None]:
# tiny plot
import json, matplotlib.pyplot as plt
log = json.loads((METRICS_DIR/"phase3_live_log.json").read_text())
xs = [r["step"] for r in log]
for k in ["l_centroid","l_norm","l_anchor","mean_norm","grad_norm","mean_cos_centroid"]:
    ys = [r[k] for r in log]
    plt.figure(); plt.plot(xs, ys); plt.title(f"P3 {k}"); plt.xlabel("step"); plt.ylabel(k); plt.tight_layout()
    plt.savefig(PLOTS_DIR/f"p3_{k}.png", dpi=160); plt.close()

# Plots and HTML

In [None]:
import matplotlib.pyplot as plt, json
# choose latest comparison
cmp = json.loads((METRICS_DIR/"morpheme_comparison_p3.json").read_text())
overlaps = np.array(cmp["overlap@5"]); deltas = np.array(cmp["norm_deltas"])

plt.figure(); plt.hist(overlaps, bins=[-0.5,0.5,1.5,2.5,3.5,4.5,5.5]); plt.title("Top-5 neighbor overlap"); plt.xlabel("Overlap"); plt.ylabel("Freq"); plt.tight_layout(); plt.savefig(PLOTS_DIR/"hist_overlap_top5.png", dpi=180); plt.close()
plt.figure(); plt.hist(deltas, bins=30); plt.title("Embedding norm change"); plt.xlabel("Δ‖E‖"); plt.ylabel("Freq"); plt.tight_layout(); plt.savefig(PLOTS_DIR/"hist_norm_change.png", dpi=180); plt.close()
plt.figure(); plt.scatter(deltas, overlaps, alpha=0.6); plt.title("Δ‖E‖ vs Overlap@5"); plt.xlabel("Δ‖E‖"); plt.ylabel("Overlap@5"); plt.tight_layout(); plt.savefig(PLOTS_DIR/"scatter_norm_vs_overlap.png", dpi=180); plt.close()

# HTML
s1 = json.loads((METRICS_DIR/"summary_stats_p1.json").read_text())
s3 = json.loads((METRICS_DIR/"summary_stats_p3.json").read_text())
html = f"""<!doctype html><html><head><meta charset="utf-8">
<title>Wake2Vec — Report {RUN_ID}</title>
<style>body{{font-family:"Times New Roman",serif;line-height:1.4}}.c{{max-width:900px;margin:2rem auto;padding:0 1rem}}</style></head>
<body><div class="c">
<h1>Wake2Vec — Morpheme-aware token expansion</h1>
<p><b>Run:</b> {RUN_ID}</p>
<ul>
<li><b>Phase 1</b> (embeddings-only warm-up): overlap@5 = {s1['mean_top5_overlap']:.3f}</li>
<li><b>Phase 3</b> (embed alignment++): overlap@5 = {s3['mean_top5_overlap']:.3f}, mean Δ‖E‖ = {s3['mean_norm_delta']:.4f}</li>
</ul>
<p>Tokenizer added {len(new_ids)} tokens; composed initial vectors saved.</p>
<img src="../runs/{RUN_ID}/plots/hist_overlap_top5.png" style="width:100%">
<img src="../runs/{RUN_ID}/plots/hist_norm_change.png" style="width:100%">
<img src="../runs/{RUN_ID}/plots/scatter_norm_vs_overlap.png" style="width:100%">
</div></body></html>"""
(REPORTS_DIR/"Wake2Vec_Report.html").write_text(html, encoding="utf-8")
print("Report:", REPORTS_DIR/"Wake2Vec_Report.html")

# Snapshot to Drive

In [None]:
import shutil, tarfile
# copy adapters, metrics, plots, report
DST = PERSIST/"runs"/RUN_ID
DST.mkdir(parents=True, exist_ok=True)
shutil.copytree(METRICS_DIR, DST/"metrics", dirs_exist_ok=True)
shutil.copytree(PLOTS_DIR,   DST/"plots",   dirs_exist_ok=True)
(PERSIST/"reports").mkdir(parents=True, exist_ok=True)
shutil.copy(REPORTS_DIR/"Wake2Vec_Report.html", PERSIST/"reports"/f"Wake2Vec_Report_{RUN_ID}.html")

# tarball
(PERSIST/"archives").mkdir(parents=True, exist_ok=True)
with tarfile.open(PERSIST/"archives"/f"{RUN_ID}.tar.gz","w:gz") as tar:
    tar.add(str(RUN_DIR), arcname=f"runs/{RUN_ID}")
print("Snapshot saved under:", PERSIST)
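
Since the tarball is the fallback copy of the run, it may be worth confirming that it actually contains the expected files before the Colab session ends. The sketch below is a self-contained illustration using only the standard library: `archive_members` is a hypothetical helper, and the run folder is a throwaway temporary directory rather than the real `RUN_DIR`.

```python
import tarfile
import tempfile
from pathlib import Path

def archive_members(archive_path):
    """Sorted names of the regular files stored in a .tar.gz archive."""
    with tarfile.open(archive_path, "r:gz") as tar:
        return sorted(m.name for m in tar.getmembers() if m.isfile())

with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    # Fake run folder with a single metrics file (stand-in for RUN_DIR).
    run_dir = tmp / "runs" / "demo_run"
    (run_dir / "metrics").mkdir(parents=True)
    (run_dir / "metrics" / "p1_stats.json").write_text("{}")

    # Same archiving pattern as the snapshot cell: add the folder under runs/<id>.
    archive = tmp / "demo_run.tar.gz"
    with tarfile.open(archive, "w:gz") as tar:
        tar.add(str(run_dir), arcname="runs/demo_run")

    members = archive_members(archive)

print(members)
```

Against a real run, the same helper could be pointed at `PERSIST/"archives"/f"{RUN_ID}.tar.gz"` and the result compared with the files listed under `RUN_DIR`.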