<a href="https://colab.research.google.com/github/mahb97/Wake2vec/blob/main/letsbuildthisthing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

did some digging and guess what i found, absolute gold, so here comes everybody

**MORPHEME-AWARE WAKE2VEC**

**Teaching TinyLlama Joyce's Generative Grammar**

Based on hand-compiled morphological analysis of Finnegans Wake

This notebook teaches compositional word formation via embedding arithmetic

In [None]:
!pip -q install python-docx

from pathlib import Path
import re, json, csv
from collections import defaultdict
from docx import Document

DOCX = Path("/content/morphemesraw.docx")
OUT_CSV = Path("/content/affixes.csv")
OUT_JSON = Path("/content/affixes.json")

if not DOCX.exists():
    raise FileNotFoundError(f"Upload your DOCX to {DOCX}")

# Normalisation helpers
def norm_token(s: str) -> str:
    s = (s or "").strip()
    s = re.sub(r"\s+", " ", s)
    # strip stray punctuation boundaries
    s = re.sub(r"^[^A-Za-z'-]+|[^A-Za-z'-]+$", "", s)
    return s

def is_header(line: str):
    """Return ('prefix','re-') or ('suffix','-ness') if line looks like a header."""
    l = line.strip().lower()
    # e.g. "prefix ad-" | "Prefix acro-" | "suffix -ness" | "Suffix a-1 and a-2"
    m1 = re.match(r'^(prefix)\s+([a-z]+)\-+(\d+)?', l)
    m2 = re.match(r'^(suffix)\s+\-+([a-z]+)(\d+)?', l)
    if m1:
        return "prefix", f"{m1.group(2)}-"
    if m2:
        return "suffix", f"-{m2.group(2)}"
    # also handle "prefix all-" with punctuation noise
    m3 = re.match(r'^(prefix)\s+([a-z\-]+)', l)
    if m3 and m3.group(2).endswith('-') and re.match(r'^[a-z\-]+-$', m3.group(2)):
        return "prefix", m3.group(2)
    m4 = re.match(r'^(suffix)\s+(\-[a-z\-]+)', l)
    if m4:
        return "suffix", m4.group(2)
    return None

doc = Document(str(DOCX))
lines = []
for para in doc.paragraphs:
    txt = para.text
    if txt is None:
        continue
    t = norm_token(txt)
    if not t:
        continue
    lines.append(t)

# Sweep through lines, collect examples under the last seen header until a new header
current = None  # (kind, morpheme)
buckets = defaultdict(list)  # (kind,morpheme) -> [examples]

for raw in lines:
    hdr = is_header(raw)
    if hdr:
        current = hdr  # set active bucket
        continue
    if not current:
        # Not under a header — skip noisy list blocks
        continue
    # explode lines that still contain counts or commas/spaces into tokens
    # keep A–Z strings (incl. hyphens, apostrophes), drop stray numbers
    tokens = re.findall(r"[A-Za-z][A-Za-z'’\-]*", raw)
    for tok in tokens:
        w = norm_token(tok)
        if not w:
            continue
        # avoid mistakenly re-adding the morpheme itself as an example
        _, morpheme = current
        if w.lower() == morpheme.strip('-').lower():
            continue
        key = current
        buckets[key].append(w)

# Deduplicate and lightly cap examples per morpheme (keeps order of first occurrence)
MAX_EXAMPLES = 150
clean = defaultdict(list)
seen_pair = set()
for (kind, morph), words in buckets.items():
    seen = set()
    out = []
    for w in words:
        wl = w.lower()
        if wl in seen:
            continue
        seen.add(wl)
        out.append(w)
        if len(out) >= MAX_EXAMPLES:
            break
    clean[(kind, morph)] = out

# Write CSV
with OUT_CSV.open("w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["kind","morpheme","example"])
    rows = 0
    for (kind, morph), lst in clean.items():
        for ex in lst:
            writer.writerow([kind, morph, ex])
            rows += 1

# Also write JSON in the simple shape your training uses
prefixes = {}
suffixes = {}
for (kind, morph), lst in clean.items():
    if kind == "prefix":
        prefixes[morph] = lst
    else:
        suffixes[morph] = lst

affixes_json = {"prefixes": prefixes, "suffixes": suffixes}
with OUT_JSON.open("w", encoding="utf-8") as f:
    json.dump(affixes_json, f, ensure_ascii=False, indent=2)

print(f"[done] wrote {OUT_CSV} and {OUT_JSON}")
print(f"  prefixes={len(prefixes)} | suffixes={len(suffixes)}")
# quick peek
for k, v in list(prefixes.items())[:5]:
    print("  prefix", k, "egs:", v[:6])
for k, v in list(suffixes.items())[:5]:
    print("  suffix", k, "egs:", v[:6])


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/253.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━[0m [32m245.8/253.0 kB[0m [31m8.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m253.0/253.0 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[?25h[done] wrote /content/affixes.csv and /content/affixes.json
  prefixes=184 | suffixes=200
  prefix ab- egs: ['ove', 'Abaft', 'Abcd', 'Abe', 'Abecedeed', 'Abject']
  prefix abs- egs: ['Absolute']
  prefix acro- egs: ['Acropoll']
  prefix ad- egs: ['Added', 'Addedto', 'Addicted', 'addition', 'Adiaptotously', 'Admiracion']
  prefix all- egs: ['allbust', 'alliance', 'allinall', 'allmarken', 'allmysty', 'alloaf']
  suffix -a egs: ['Ada', 'Africa', 'Allsea', 'America', 'Anna', 'Aquila']
  suffix -ata egs: ['Cryptoconchoidsiphonostomata', 'Spezzata']
  suffix -able egs: ['alloilable', 'im-pugnable', 'Impermeable', 'inevita

In [None]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
os.environ["PYTHONHASHSEED"] = "1337"

import torch
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.use_deterministic_algorithms(True)

import random, json, math, re
import numpy as np
from datetime import datetime
from pathlib import Path
from collections import defaultdict, Counter

SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print("Morpheme chaos mode activated")

Morpheme chaos mode activated


config

In [None]:
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
CORPUS_PATH = "/content/fw.txt"
MORPHEME_DATA_PATH = "/content/affixes_terms.txt"

# Training params
BATCH_SIZE = 2
BLOCK_SIZE = 256
EPOCHS = 2
LR = 2e-5
WARMUP_RATIO = 0.05
WEIGHT_DECAY = 0.01
GRAD_ACCUM = 4
SAVE_STEPS = 200

# Morpheme chaos params
SYNTHETIC_PER_MORPHEME = 10  # Generate N examples per morpheme combo
COMPOSITION_ALPHA = 0.33     # Weight for prefix:root:suffix (0.33:0.34:0.33)
MORPHEME_NOISE = 0.05        # Add chaos to composed embeddings

# Output
RUN_ID = datetime.now().strftime("morpheme_wake_%Y%m%d_%H%M")
OUTDIR = Path(f"./runs/{RUN_ID}")
(OUTDIR / "results").mkdir(parents=True, exist_ok=True)
(OUTDIR / "checkpoints").mkdir(parents=True, exist_ok=True)

print(f"Run ID: {RUN_ID}")
print(f"Teaching TinyLlama Joyce's morphological grammar...")

Run ID: morpheme_wake_20251030_0017
Teaching TinyLlama Joyce's morphological grammar...


In [None]:
def load_corpus(path):
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Corpus not found: {p}")
    text = p.read_text(encoding="utf-8", errors="ignore")
    print(f"✓ Loaded corpus: {len(text)} chars")
    return text

FW_TEXT = load_corpus(CORPUS_PATH)

✓ Loaded corpus: 1364712 chars


parse morpheme data

In [None]:
def parse_morpheme_document(text):
    """
    Parse your hand-compiled morpheme analysis into structured data.
    Returns: {prefixes: {}, suffixes: {}, examples: {}}
    """
    data = {
        'prefixes': defaultdict(list),
        'suffixes': defaultdict(list),
        'infixes': defaultdict(list),
        'prefix_counts': Counter(),
        'suffix_counts': Counter(),
    }

    lines = text.split('\n')
    current_type = None
    current_morph = None

    for line in lines:
        line = line.strip()
        if not line:
            continue

        # Detect section headers
        if line.startswith('Prefix') or line.startswith('prefix'):
            current_type = 'prefix'
            # Extract morpheme: "Prefix ab- 13" -> "ab-"
            parts = line.split()
            if len(parts) >= 2:
                current_morph = parts[1].lower()
                count = int(parts[2]) if len(parts) > 2 and parts[2].isdigit() else 1
                data['prefix_counts'][current_morph] = count

        elif line.startswith('Suffix') or line.startswith('suffix'):
            current_type = 'suffix'
            parts = line.split()
            if len(parts) >= 2:
                current_morph = parts[1].lower()
                count = int(parts[2]) if len(parts) > 2 and parts[2].isdigit() else 1
                data['suffix_counts'][current_morph] = count

        elif line.startswith('Infix') or line.startswith('infix'):
            current_type = 'infix'
            parts = line.split()
            if len(parts) >= 2:
                current_morph = parts[1].lower()

        # Collect examples (lines that aren't headers)
        elif current_type and current_morph:
            # Skip lines with numbers only or special headers
            if line[0].isupper() and current_type in ['prefix', 'suffix']:
                continue

            # Clean example words
            word = line.split()[0] if line.split() else line
            word = word.strip('.,;:()[]{}\"\'')

            if word and len(word) > 1:
                if current_type == 'prefix':
                    data['prefixes'][current_morph].append(word)
                elif current_type == 'suffix':
                    data['suffixes'][current_morph].append(word)
                elif current_type == 'infix':
                    data['infixes'][current_morph].append(word)

    # Convert defaultdicts to regular dicts
    data['prefixes'] = dict(data['prefixes'])
    data['suffixes'] = dict(data['suffixes'])
    data['infixes'] = dict(data['infixes'])

    return data

# Load and parse your morpheme data
print("\n" + "="*60)
print("PARSING HAND-COMPILED MORPHEME DATA")
print("="*60)

morpheme_doc = Path(MORPHEME_DATA_PATH)
if morpheme_doc.exists():
    morpheme_text = morpheme_doc.read_text(encoding="utf-8", errors="ignore")
    MORPHEME_DATA = parse_morpheme_document(morpheme_text)
else:
    # Fallback: extract from the text you pasted inline
    print("⚠ Morpheme data file not found, using inline extraction...")
    # You can paste your dataset here as a string if needed
    MORPHEME_DATA = {
        'prefixes': {
            'ab-': ['above', 'abaft', 'abject', 'abler'],
            'anti-': ['anticipation', 'antipathies'],
            'circum-': ['circumvallator'],
            'hyper-': ['hyperchemical'],
            'sub-': ['subject', 'substrate', 'subordinating'],
        },
        'suffixes': {
            '-ation': ['acclammitation', 'anticipation', 'paupulation'],
            '-ous': ['delicious', 'precious', 'gracious'],
            '-ness': ['darkness', 'sweetness', 'softness'],
            '-ing': ['going', 'coming', 'being'],
        },
        'prefix_counts': Counter({'ab-': 13, 'anti-': 2, 'circum-': 1, 'hyper-': 1, 'sub-': 7}),
        'suffix_counts': Counter({'-ation': 38, '-ous': 49, '-ness': 28, '-ing': 257}),
    }

print(f"Parsed {len(MORPHEME_DATA['prefixes'])} prefixes")
print(f"Parsed {len(MORPHEME_DATA['suffixes'])} suffixes")
print(f"\nTop prefixes by frequency:")
for morph, count in MORPHEME_DATA['prefix_counts'].most_common(10):
    print(f"  {morph}: {count}")
print(f"\nTop suffixes by frequency:")
for morph, count in MORPHEME_DATA['suffix_counts'].most_common(10):
    print(f"  {morph}: {count}")


PARSING HAND-COMPILED MORPHEME DATA
✓ Parsed 0 prefixes
✓ Parsed 1 suffixes

Top prefixes by frequency:

Top suffixes by frequency:
  –‘s: 1


In [None]:
# ==== Build MORPHEME_DATA from a flat term list (affixes_terms.txt) ====
import re
from pathlib import Path
from collections import defaultdict

AFFIX_TXT = Path("/content/affixes_terms.txt")  # one term per line
EXAMPLES_PER_MORPHEME = 120                     # cap examples per morpheme

if not AFFIX_TXT.exists():
    raise FileNotFoundError(f"Missing {AFFIX_TXT}. Upload your affixes_terms.txt first.")

# --- Curated affix lists (extend as you like) ---
PREFIXES = [
    "anti-","ante-","arch-","auto-","bi-","bio-","co-","con-","contra-","counter-",
    "crypto-","de-","dis-","down-","en-","em-","ex-","extra-","fore-","geo-","hetero-",
    "homo-","hyper-","hypo-","il-","im-","in-","inter-","intra-","ir-","macro-",
    "mega-","meta-","micro-","mid-","mis-","mono-","multi-","neo-","non-","omni-",
    "over-","pan-","para-","peri-","poly-","post-","pre-","pro-","proto-","pseudo-",
    "re-","semi-","sub-","super-","supra-","sur-","tele-","trans-","tri-","ultra-",
    "un-","under-","uni-","up-","vice-",
    # Joycean-friendly add-ons
    "quasi-","infra-","intro-","out-","sous-","uber-"
]
PREFIXES.sort(key=len, reverse=True)

SUFFIXES = [
    "-ability","-ibility","-ation","-ition","-ication","-ization",
    "-ment","-ness","-less","-ful","-able","-ible","-ish","-ism","-ist","-ity","-ety",
    "-ive","-ative","-tive","-al","-ial","-ual","-ary","-ory","-ature",
    "-ous","-eous","-ious","-esque","-ific","-logue","-logy","-ology","-ography",
    "-ship","-hood","-ward","-wards","-wise","-y","-ly","-er","-or","-eer","-eur",
    "-ette","-let","-ling","-kin","-ance","-ence","-ancy","-ency","-ure",
    "-ium","-um","-arium","-orium","-dom",
    # (Optional) inflectional:
    "-ing","-ed","-en","-s","-es",
    # Joycean-friendly
    "-scape","-some","-smith","-most","-worthy","-gate","-tron","-plex","-polis"
]
SUFFIXES.sort(key=len, reverse=True)

# --- helpers ---
def _clean_line(s: str) -> str:
    s = (s or "").strip()
    s = re.sub(r"\s+", " ", s)
    # keep letters, hyphens, apostrophes; strip outer junk
    s = re.sub(r"^[^A-Za-z'’-]+|[^A-Za-z'’-]+$", "", s)
    s = s.replace("’", "'")
    return s

def _longest_prefix(word: str):
    wl = word.lower()
    for p in PREFIXES:
        core = p[:-1]  # drop trailing '-'
        if wl.startswith(core) and len(wl) > len(core)+0:  # allow short roots
            return p
    return None

def _longest_suffix(word: str):
    wl = word.lower()
    for s in SUFFIXES:
        core = s[1:]  # drop leading '-'
        if wl.endswith(core) and len(wl) > len(core)+0:
            return s
    return None

# --- load terms ---
terms = []
with AFFIX_TXT.open("r", encoding="utf-8", errors="ignore") as f:
    for line in f:
        t = _clean_line(line)
        if not t or not re.search(r"[A-Za-z]", t):
            continue
        # skip accidental section headers like "Prefix", "Suffix" if present
        if re.match(r"(?i)^(prefix|suffix|infix)\b", t):
            continue
        terms.append(t)

# --- build morpheme -> examples maps ---
prefix_map = defaultdict(list)
suffix_map = defaultdict(list)

for w in terms:
    p = _longest_prefix(w)
    s = _longest_suffix(w)
    if p:
        prefix_map[p].append(w)
    if s:
        suffix_map[s].append(w)

# --- dedupe & cap examples ---
def _dedupe_cap(d):
    out = {}
    for k, lst in d.items():
        seen = set()
        keep = []
        for w in lst:
            wl = w.lower()
            if wl in seen:
                continue
            seen.add(wl)
            keep.append(w)
            if len(keep) >= EXAMPLES_PER_MORPHEME:
                break
        out[k] = keep
    return out

prefix_map = _dedupe_cap(prefix_map)
suffix_map = _dedupe_cap(suffix_map)

# --- final structure for your functions ---
MORPHEME_DATA = {
    "prefixes": dict(prefix_map),
    "suffixes": dict(suffix_map),
}

# --- preview ---
def _peek(d, n=5):
    return {k: d[k][:min(len(d[k]), 3)] for k in list(d.keys())[:n]}

print(f"[MORPHEME_DATA] prefixes={len(MORPHEME_DATA['prefixes'])} | suffixes={len(MORPHEME_DATA['suffixes'])}")
print("  sample prefixes:", _peek(MORPHEME_DATA["prefixes"]))
print("  sample suffixes:", _peek(MORPHEME_DATA["suffixes"]))


[MORPHEME_DATA] prefixes=57 | suffixes=63
  sample prefixes: {'anti-': ['anti-cipation', 'Anticipation', 'antipathies'], 'auto-': ['Autotone', 'autotune'], 'co-': ['Cocoa', 'columna', 'Coal'], 'con-': ['Conna', 'Constellatria', 'conceal'], 'crypto-': ['Cryptoconchoidsiphonostomata']}
  sample suffixes: {'-s': ['Aas', 'accacians', 'Accidents'], '-y': ['Abby', 'Accuracy', 'aisy'], '-ed': ['abecedeed', 'accorded', 'Achamed'], '-er': ['abler', 'Acquiester', 'admirer'], '-al': ['Accidental', 'Apersonal', 'Appeal']}


 LOAD MODEL & TOKENIZER

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

set_seed(SEED)

tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model.to(DEVICE)

print(f"\nModel: {MODEL_NAME}")
print(f"Device: {DEVICE}")
print(f"Initial vocab: {len(tok)}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]


Model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
Device: cuda
Initial vocab: 32000


In [None]:
import re
from pathlib import Path
from collections import defaultdict

def _clean_line(s: str) -> str:
    s = (s or "").strip()
    s = re.sub(r"\b\d+\b$", "", s).strip()                 # drop trailing counts
    s = re.sub(r"\((sic|decap)\)", "", s, flags=re.I).strip()
    return s

def _normalise_affix(kind: str, aff: str) -> str:
    aff = (aff or "").replace("–","-").strip()
    if not aff: return ""
    if kind == "prefix":
        return aff if aff.endswith("-") else (aff + "-")
    if kind == "suffix":
        return aff if aff.startswith("-") else ("-" + aff)
    return aff

prefix_map = MORPHEME_DATA.get("prefixes", {})
suffix_map = MORPHEME_DATA.get("suffixes", {})

def _dedupe_cap(d, limit):
    out = {}
    for k, lst in d.items():
        seen = set()
        uniq = []
        for w in lst:
            wl = w.lower()
            if wl not in seen:
                seen.add(wl)
                uniq.append(w)
            if len(uniq) >= limit:
                break
        out[k] = uniq
    return out

def _peek(d, n=5):
    # Ensure we only sample from keys that actually have examples
    valid_keys = [k for k in d.keys() if d[k]]
    sample_keys = random.sample(valid_keys, min(n, len(valid_keys))) if valid_keys else []
    return {k: d[k][:min(len(d[k]),3)] for k in sample_keys}


print(f"[MORPHEME_DATA] prefixes={len(MORPHEME_DATA.get('prefixes', {}))} | suffixes={len(MORPHEME_DATA.get('suffixes', {}))}")
print("  sample prefixes:", _peek(MORPHEME_DATA.get("prefixes", {})))
print("  sample suffixes:", _peek(MORPHEME_DATA.get("suffixes", {})))

[MORPHEME_DATA] prefixes=57 | suffixes=63
  sample prefixes: {'out-': ['Outnullused', 'Outer', 'Outing'], 'semi-': ['semination', 'semicolonials'], 'fore-': ['Forest', 'Forecast', 'Foretellers'], 'homo-': ['homo-gallant', 'homoheatherous', 'Homoid'], 'mega-': ['Megacene', 'Megalopolis']}
  sample suffixes: {'-ication': ['Fornication', 'Gratification', 'Intoxication'], '-ness': ['deleteriousness of decorousness', 'badness', 'Breadchestviousness'], '-ance': ['acquointance', 'advance', 'Alliance'], '-ette': ['liffeyette', 'pepette', 'pette'], '-hood': ['childhood']}


In [None]:
import torch
import torch.nn.functional as F

# sensible defaults (honour existing globals if defined elsewhere)
COMPOSITION_ALPHA = float(globals().get("COMPOSITION_ALPHA", 0.2))
MORPHEME_NOISE    = float(globals().get("MORPHEME_NOISE", 0.03))
DEVICE            = next(model.parameters()).device

def _tok_ids(text, tokenizer):
    return tokenizer(text, add_special_tokens=False)["input_ids"]

def _vec_mean(ids, emb_matrix):
    if not ids:
        return None
    vecs = emb_matrix[torch.as_tensor(ids, device=emb_matrix.device)]
    return vecs.mean(dim=0)

def _single_token_id(text, tokenizer):
    ids = _tok_ids(text, tokenizer)
    return ids[0] if len(ids) == 1 else None

def _kind_of(morpheme: str):
    # prefixes end with '-', suffixes start with '-'
    if morpheme.endswith('-'): return "prefix"
    if morpheme.startswith('-'): return "suffix"
    return None

def _nearest_morpheme_backoff(morpheme, emb_matrix, tokenizer, k=8):
    """
    Try to find another morpheme of the same kind with examples,
    take the average embedding across a few of its examples.
    """
    kind = _kind_of(morpheme)
    if kind is None:
        return None

    pool = MORPHEME_DATA["prefixes"] if kind == "prefix" else MORPHEME_DATA["suffixes"]
    if not pool:
        return None

    # crude lexical proximity: same first letters (after stripping hyphen)
    stem = morpheme.strip('-').lower()
    candidates = []
    for m, exs in pool.items():
        if not exs:
            continue
        score = 0
        mstem = m.strip('-').lower()
        # shared prefix length heuristic
        for a, b in zip(stem, mstem):
            if a == b: score += 1
            else: break
        candidates.append((score, m, exs))

    if not candidates:
        return None
    candidates.sort(key=lambda x: x[0], reverse=True)
    # take a few top candidates’ example words
    ex_words = []
    for _, _, exs in candidates[:k]:
        ex_words.extend(exs[:5])
    if not ex_words:
        return None

    emb_list = []
    for w in ex_words:
        ids = _tok_ids(w.lower(), tokenizer)
        v = _vec_mean(ids, emb_matrix)
        if v is not None:
            emb_list.append(v)
    if emb_list:
        return torch.stack(emb_list).mean(dim=0)
    return None

def find_morpheme_embedding(morpheme, model, tokenizer):
    """
    Strategy:
    1) If morpheme is a single tokenizer piece, return that vector.
    2) Else average embeddings of example words containing it (token-averaged).
    3) Else average embeddings of morpheme's own subtokens.
    4) Else back off to a similar morpheme's examples.
    5) Else random normal scaled to emb std.
    """
    emb_matrix = model.get_input_embeddings().weight.data

    # --- 1) single-token morpheme (e.g., some BPEs may have "re" + "-" separate; we require length == 1)
    stid = _single_token_id(morpheme, tokenizer)
    if stid is not None:
        return emb_matrix[stid].clone()

    # --- 2) average example words for this morpheme (from MORPHEME_DATA)
    examples = []
    if morpheme in MORPHEME_DATA.get('prefixes', {}):
        examples = MORPHEME_DATA['prefixes'][morpheme][:32]
    elif morpheme in MORPHEME_DATA.get('suffixes', {}):
        examples = MORPHEME_DATA['suffixes'][morpheme][:32]

    if examples:
        emb_list = []
        for w in examples:
            ids = _tok_ids(w.lower(), tokenizer)
            v = _vec_mean(ids, emb_matrix)
            if v is not None:
                emb_list.append(v)
        if emb_list:
            return torch.stack(emb_list).mean(dim=0)

    # --- 3) average the morpheme’s own subtokens (strip the hyphen so BPE sees letters)
    morph_text = morpheme.strip('-').lower()
    ids = _tok_ids(morph_text, tokenizer)
    v = _vec_mean(ids, emb_matrix)
    if v is not None:
        return v

    # --- 4) neighbor backoff from same kind
    v = _nearest_morpheme_backoff(morpheme, emb_matrix, tokenizer)
    if v is not None:
        return v

    # --- 5) final fallback: random normal scaled to overall emb std
    return torch.randn(emb_matrix.shape[1], device=emb_matrix.device) * emb_matrix.std()

def compose_morpheme_embedding(prefix, root, suffix, model, tokenizer):
    """
    E(word) = α * E(prefix) + β * E(root) + γ * E(suffix),  where β = 1 - 2α
    If a component is missing, its vector is a zero vector.
    Root uses token-mean; if missing, small random normal.
    """
    emb_matrix = model.get_input_embeddings().weight.data
    dim = emb_matrix.shape[1]
    alpha = COMPOSITION_ALPHA
    beta  = 1.0 - 2.0 * alpha
    gamma = alpha

    # prefix / suffix vecs (zero if None/empty)
    if prefix:
        prefix_vec = find_morpheme_embedding(prefix, model, tokenizer)
    else:
        prefix_vec = torch.zeros(dim, device=emb_matrix.device)

    if suffix:
        suffix_vec = find_morpheme_embedding(suffix, model, tokenizer)
    else:
        suffix_vec = torch.zeros(dim, device=emb_matrix.device)

    # root vec: average token pieces (more robust than single-id lookup)
    root_ids = _tok_ids((root or "").lower(), tokenizer)
    if root_ids:
        root_vec = _vec_mean(root_ids, emb_matrix)
    else:
        root_vec = None
    if root_vec is None:
        root_vec = torch.randn(dim, device=emb_matrix.device) * 0.02

    composed = alpha * prefix_vec + beta * root_vec + gamma * suffix_vec

    # small noise for diversity/stability
    if MORPHEME_NOISE > 0:
        std = composed.detach().std().clamp(min=1e-6)
        composed = composed + torch.randn_like(composed) * (MORPHEME_NOISE * std)

    return composed


gen synthetic wake words via morpheme combination

In [None]:
import random
from collections import Counter

# --- ensure counts exist from MORPHEME_DATA['prefixes'/'suffixes'] ---
def _ensure_morpheme_counts():
    px = {m: len(exs) for m, exs in MORPHEME_DATA.get("prefixes", {}).items()}
    sx = {m: len(exs) for m, exs in MORPHEME_DATA.get("suffixes", {}).items()}
    # fallback if empty
    if not px:
        px = {"re-": 1, "un-": 1, "in-": 1}
    if not sx:
        sx = {"-ness": 1, "-al": 1, "-ity": 1}
    MORPHEME_DATA["prefix_counts"] = px
    MORPHEME_DATA["suffix_counts"] = sx

_ensure_morpheme_counts()

def generate_morpheme_words(n_samples=1000, p_prefix=0.7, p_suffix=0.8, roots=None, dedupe=True):
    """
    Generate synthetic 'Wake-ish' words by composing prefix + root + suffix.
    Frequencies are proportional to #examples recorded for each morpheme.
    """
    generated = []

    prefixes = list(MORPHEME_DATA["prefix_counts"].keys())
    suffixes = list(MORPHEME_DATA["suffix_counts"].keys())
    pw = [max(1, MORPHEME_DATA["prefix_counts"][p]) for p in prefixes]
    sw = [max(1, MORPHEME_DATA["suffix_counts"][s]) for s in suffixes]

    if not roots:
        roots = [
            'dream','river','thunder','word','night','day','wake','sleep',
            'fire','water','time','man','woman','king','queen','stone',
            'tree','moon','sun','star','wind','rain','storm','cloud',
            'book','letter','voice','sound','song','dance','walk','run'
        ]

    seen = set()
    for _ in range(n_samples):
        use_p = (random.random() < p_prefix) and len(prefixes) > 0
        use_s = (random.random() < p_suffix) and len(suffixes) > 0

        prefix = random.choices(prefixes, weights=pw, k=1)[0] if use_p else None
        suffix = random.choices(suffixes, weights=sw, k=1)[0] if use_s else None
        root   = random.choice(roots)

        # build word string without hyphen artifacts
        def strip_hy(s, is_prefix):
            if not s: return ""
            return s[:-1] if is_prefix else s[1:]  # drop trailing '-' for prefix; leading '-' for suffix

        word = f"{strip_hy(prefix, True)}{root}{strip_hy(suffix, False)}"

        if dedupe:
            if word in seen:
                continue
            seen.add(word)

        generated.append({
            "word": word,
            "prefix": prefix,
            "root": root,
            "suffix": suffix
        })

    return generated

print("\n" + "="*60)
print("GENERATING SYNTHETIC WAKE WORDS")
print("="*60)

synthetic_words = generate_morpheme_words(n_samples=500)
print(f"✓ Generated {len(synthetic_words)} morphological neologisms")

print("\nExamples:")
for w in synthetic_words[:20]:
    print(f"  {w['word']:20s} ({w['prefix'] or 'Ø'} + {w['root']} + {w['suffix'] or 'Ø'})")


GENERATING SYNTHETIC WAKE WORDS
✓ Generated 471 morphological neologisms

Examples:
  inlettersmith        (in- + letter + -smith)
  cloudor              (Ø + cloud + -or)
  innight              (in- + night + Ø)
  misstares            (mis- + star + -es)
  imdance              (im- + dance + Ø)
  emdanceling          (em- + dance + -ling)
  unthunder            (un- + thunder + Ø)
  misword              (mis- + word + Ø)
  consonger            (con- + song + -er)
  wind                 (Ø + wind + Ø)
  enqueen              (en- + queen + Ø)
  unkingor             (un- + king + -or)
  reworded             (re- + word + -ed)
  riverly              (Ø + river + -ly)
  stormity             (Ø + storm + -ity)
  treeor               (Ø + tree + -or)
  codayed              (co- + day + -ed)
  overword             (over- + word + Ø)
  imthunderary         (im- + thunder + -ary)
  constormance         (con- + storm + -ance)


token injection

In [None]:
import math, random, torch
from datasets import Dataset

# 1) Make tiny Joyce-ish sentences so the new tokens appear in context
def lines_from_synth(synth, per_word=3):
    patt = [
        "He said {w} and smiled at the rivery din.",
        "They hummed {w} through the nightlong thunder.",
        "A {w} fell between letter and time.",
        "Write it as {w}, he urged, once and again.",
        "Under cloud and over stone, {w} kept talking."
    ]
    out = []
    for item in synth:
        w = item["word"]
        for _ in range(per_word):
            out.append(random.choice(patt).format(w=w))
    return out

aug_lines = lines_from_synth(synthetic_words, per_word=3)
print(f"[augment] {len(aug_lines)} synthetic lines from {len(synthetic_words)} words")

# 2) Add only words that aren't already single tokens; init embeddings by morpheme composition
def is_single_token(tokenizer, s):
    return len(tokenizer(s, add_special_tokens=False)["input_ids"]) == 1

to_add = [x for x in synthetic_words if not is_single_token(tok, x["word"])]
new_tokens = [x["word"] for x in to_add]

n_before = len(tok)
added = tok.add_tokens(new_tokens, special_tokens=False)
model.resize_token_embeddings(len(tok))
print(f"[tokenizer] added {added} new tokens (from {len(new_tokens)} candidates)")

# compose vectors for the *added* words and write them into the embedding matrix (tied to lm_head)
with torch.no_grad():
    emb = model.get_input_embeddings().weight
    # map word -> new id (only for those actually added)
    vocab = tok.get_vocab()
    inv = {t:i for t,i in vocab.items()}  # token -> id
    wrote = 0
    for item in to_add:
        w = item["word"]
        tid = inv.get(w, None)
        if tid is None:
            continue  # not actually added (maybe was already in vocab as 1-piece)
        p, r, s = item.get("prefix"), item.get("root"), item.get("suffix")
        vec = compose_morpheme_embedding(p, r, s, model, tok)
        emb[tid].copy_(vec)
        wrote += 1

    # ensure lm_head is tied (GPT-2 style)
    if hasattr(model, "lm_head") and model.lm_head.weight.data_ptr() != emb.data_ptr():
        model.lm_head.weight = torch.nn.Parameter(emb)

print(f"[init] wrote composed embeddings for {wrote} new tokens")

# 3) Build combined training text (FW + synthetic)
if 'FW_TEXT' not in globals() or not FW_TEXT:
    raise RuntimeError("FW_TEXT is not loaded. Set FW_TEXT to your Finnegans Wake text string first.")

COMBINED_TEXT = FW_TEXT + "\n" + "\n".join(aug_lines)

# Convert to block dataset for causal LM
def make_blocks_from_text(text, tokenizer, block_size=256):
    ids = tokenizer(text, add_special_tokens=False, return_attention_mask=False)["input_ids"]
    n = len(ids) // block_size
    if n == 0:
        raise ValueError(f"Text too short: {len(ids)} ids for block_size={block_size}")
    ids = ids[: n * block_size]
    arr = torch.tensor(ids, dtype=torch.int32).view(n, block_size).tolist()
    return Dataset.from_dict({"input_ids": arr})

BLOCK_SIZE = int(globals().get("BLOCK_SIZE", 256))
ds = make_blocks_from_text(COMBINED_TEXT, tok, BLOCK_SIZE)

def _add_labels(b): return {"labels": b["input_ids"]}
N = len(ds)
cut = int(0.9 * N) if N > 10 else N
train = ds.select(range(cut)).map(_add_labels, batched=True)
valid = ds.select(range(cut, N)).map(_add_labels, batched=True)
print(f"[dataset] blocks total={N} | train={len(train)} | valid={len(valid)}")

# quick peek: confirm a few new tokens are single-piece now
probe = [w for w in new_tokens[:10]]
lens = {w: len(tok(w, add_special_tokens=False)["input_ids"]) for w in probe}
print("[probe] tokenization lengths (expect 1):", lens)


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


[augment] 1413 synthetic lines from 471 words


The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


[tokenizer] added 447 new tokens (from 447 candidates)
[init] wrote composed embeddings for 447 new tokens


Token indices sequence length is longer than the specified maximum sequence length for this model (445614 > 2048). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/1566 [00:00<?, ? examples/s]

Map:   0%|          | 0/174 [00:00<?, ? examples/s]

[dataset] blocks total=1740 | train=1566 | valid=174
[probe] tokenization lengths (expect 1): {'inlettersmith': 1, 'cloudor': 1, 'innight': 1, 'misstares': 1, 'imdance': 1, 'emdanceling': 1, 'unthunder': 1, 'misword': 1, 'consonger': 1, 'enqueen': 1}


In [None]:
# Turn OFF mean-resizing if your HF version supports it (keeps own init in control)
import inspect

def resize_tok_emb_no_mean(model, new_size):
    fn = model.resize_token_embeddings
    sig = inspect.signature(fn)
    if "mean_resizing" in sig.parameters:
        return fn(new_size, mean_resizing=False)
    return fn(new_size)

# Example usage after tokenizer.add_tokens(...)
n_before = len(tok)
added = tok.add_tokens(new_tokens, special_tokens=False)
resize_tok_emb_no_mean(model, len(tok))

# (Re)tie lm_head to embeddings (GPT-2 family)
with torch.no_grad():
    emb = model.get_input_embeddings().weight
    if hasattr(model, "lm_head") and model.lm_head.weight.data_ptr() != emb.data_ptr():
        model.lm_head.weight = torch.nn.Parameter(emb)


## SYNTHETIC TRAINING DATA GENERATION

pre training snap

In [None]:
# === PRE-TRAINING SNAPSHOT ===
from pathlib import Path
import json, torch, hashlib
from datetime import datetime

SNAP_DIR = Path(OUTDIR) / f"pretrain_snapshot_{datetime.now().strftime('%Y%m%d_%H%M')}"
SNAP_DIR.mkdir(parents=True, exist_ok=True)

# 1) Save tokenizer (captures the 447 new tokens you've added)
tok.save_pretrained(SNAP_DIR)

# 2) Save model weights (before any new training)
model.save_pretrained(SNAP_DIR)

# 3) Save vocab + new_tokens + quick embedding stats
vocab = tok.get_vocab()
inv_vocab = {i:t for t,i in vocab.items()}
emb = model.get_input_embeddings().weight.detach().float().cpu()

stats = {
    "vocab_size": len(vocab),
    "embedding_dim": emb.shape[1],
    "embedding_mean": float(emb.mean()),
    "embedding_std":  float(emb.std()),
}

# store an md5 hash over the embedding tensor for later diff checks
md5 = hashlib.md5(emb.numpy().tobytes()).hexdigest()
stats["embedding_md5"] = md5

# if you kept `new_tokens` from earlier cell, persist them; else store empty list
try:
    new_tokens  # noqa
except NameError:
    new_tokens = []

(Path(SNAP_DIR / "embedding_stats.json")).write_text(json.dumps(stats, indent=2))
(Path(SNAP_DIR / "new_tokens.json")).write_text(json.dumps(new_tokens, indent=2))

print("[snapshot] saved to:", SNAP_DIR)
print(stats)


[snapshot] saved to: runs/morpheme_wake_20251030_0017/pretrain_snapshot_20251030_0100
{'vocab_size': 32445, 'embedding_dim': 2048, 'embedding_mean': -3.6155972793494584e-07, 'embedding_std': 0.014816503040492535, 'embedding_md5': '31866c7862ffb554e4e362cbad659e69'}


In [None]:
def generate_morpheme_sentences(word_data, per_word=5):
    """
    Generate training sentences that showcase morphological patterns.
    """
    patterns = [
        "The {word} of {root} echoes through the Wake.",
        "By {word} and by {root}, the river flows.",
        "In the {word} of night, {root} speaks.",
        "From {root} to {word}, the tale unwinds.",
        "He spoke of {word} as if {root} remembered.",
        "{word} upon {word}, the {root} multiplies.",
        "Through {word} and beyond {root}, voices drift.",
        "The {word} contains the {root} contains the word.",
        "Call it {word}, call it {root}-become-language.",
        "Riverrun past {word} and {root} from swerve of shore.",
    ]

    sentences = []
    for item in word_data:
        for _ in range(per_word):
            pattern = random.choice(patterns)
            sentence = pattern.format(word=item['word'], root=item['root'])
            sentences.append(sentence)

    return sentences

synthetic_sentences = generate_morpheme_sentences(synthetic_words[:200], per_word=SYNTHETIC_PER_MORPHEME)
random.shuffle(synthetic_sentences)

print(f"\n✓ Generated {len(synthetic_sentences)} training sentences")
print("\nSample sentences:")
for s in synthetic_sentences[:5]:
    print(f"  {s}")

# Combine with original Wake text
COMBINED_TEXT = FW_TEXT + "\n" + "\n".join(synthetic_sentences)
print(f"\n✓ Combined corpus: {len(COMBINED_TEXT)} chars")

# Save generated words for analysis
with open(OUTDIR / "results" / "generated_morpheme_words.json", "w") as f:
    json.dump(synthetic_words, f, indent=2)


✓ Generated 2000 training sentences

Sample sentences:
  From letter to inlettersmith, the tale unwinds.
  From king to unkingor, the tale unwinds.
  The emsoundist of sound echoes through the Wake.
  In the redayal of night, day speaks.
  semiqueens upon semiqueens, the queen multiplies.

✓ Combined corpus: 1457308 chars


In [None]:
import random

# knobs
N_SYN_WORDS   = 800       # number of neologisms to sample
P_PREFIX      = 0.72      # prob of using a prefix
P_SUFFIX      = 0.82      # prob of using a suffix
PER_WORD_LINES= 3         # how many sentences per new word

# 1) generate synthetic words (uses MORPHEME_DATA, created earlier)
synthetic_words = generate_morpheme_words(
    n_samples=N_SYN_WORDS,
    p_prefix=P_PREFIX,
    p_suffix=P_SUFFIX,
    dedupe=True
)

print(f"[synthetic] words={len(synthetic_words)} (p_prefix={P_PREFIX}, p_suffix={P_SUFFIX})")
print("  e.g.:", ", ".join([w["word"] for w in synthetic_words[:10]]))

# 2) convert to short Joyce-ish sentences
def lines_from_synth(synth, per_word=PER_WORD_LINES):
    patt = [
        "He said {w} and smiled at the rivery din.",
        "They hummed {w} through the nightlong thunder.",
        "A {w} fell between letter and time.",
        "Write it as {w}, he urged, once and again.",
        "Under cloud and over stone, {w} kept talking."
    ]
    out = []
    for item in synth:
        w = item["word"]
        for _ in range(per_word):
            out.append(random.choice(patt).format(w=w))
    return out

aug_lines = lines_from_synth(synthetic_words, per_word=PER_WORD_LINES)
print(f"[synthetic] lines={len(aug_lines)} (≈ {PER_WORD_LINES} per word)")


[synthetic] words=734 (p_prefix=0.72, p_suffix=0.82)
  e.g.: conmanes, presounder, soundity, codayor, outdanceist, maned, corunous, coraines, consongen, destarment
[synthetic] lines=2202 (≈ 3 per word)


data set prep

In [None]:
# === DATASET PREP: add tokens, init vectors, build block dataset ===
import inspect, torch
from datasets import Dataset

BLOCK_SIZE  = int(globals().get("BLOCK_SIZE", 256))
SAVE_STEPS  = int(globals().get("SAVE_STEPS", 200))

def resize_tok_emb_no_mean(model, new_size):
    fn = model.resize_token_embeddings
    sig = inspect.signature(fn)
    if "mean_resizing" in sig.parameters:
        return fn(new_size, mean_resizing=False)
    return fn(new_size)

def is_single_token(tokenizer, s):
    return len(tokenizer(s, add_special_tokens=False)["input_ids"]) == 1

# 1) determine which synthetic words need to be added
to_add = [x for x in synthetic_words if not is_single_token(tok, x["word"])]
new_tokens = [x["word"] for x in to_add]
added = tok.add_tokens(new_tokens, special_tokens=False)
resize_tok_emb_no_mean(model, len(tok))
print(f"[tokenizer] candidates={len(new_tokens)} | actually_added={added}")

# 2) write composed embeddings only for actually-added tokens
with torch.no_grad():
    emb = model.get_input_embeddings().weight
    vocab = tok.get_vocab()
    wrote = 0
    for item in to_add:
        tid = vocab.get(item["word"])
        if tid is None:
            continue
        vec = compose_morpheme_embedding(item.get("prefix"), item.get("root"), item.get("suffix"), model, tok)
        emb[tid].copy_(vec)
        wrote += 1
    if hasattr(model, "lm_head") and model.lm_head.weight.data_ptr() != emb.data_ptr():
        model.lm_head.weight = torch.nn.Parameter(emb)
print(f"[init] morpheme-composed vectors written for {wrote} tokens")

# 3) build combined text from FW + synthetic lines
assert 'FW_TEXT' in globals() and FW_TEXT, "FW_TEXT missing; load your Finnegans Wake text string."
COMBINED_TEXT = FW_TEXT + "\n" + "\n".join(aug_lines)

# avoid any full-sequence forwards later:
def make_blocks_from_text(text, tokenizer, block_size=BLOCK_SIZE):
    ids = tokenizer(text, add_special_tokens=False, return_attention_mask=False)["input_ids"]
    n = len(ids) // block_size
    if n == 0:
        raise ValueError(f"Text too short: {len(ids)} ids for block_size={block_size}")
    ids = ids[: n * block_size]
    import numpy as np
    arr = np.array(ids, dtype=np.int32).reshape(n, block_size)
    return Dataset.from_dict({"input_ids": arr.tolist()})

ds = make_blocks_from_text(COMBINED_TEXT, tok, BLOCK_SIZE)

def _add_labels(b): return {"labels": b["input_ids"]}
N = len(ds)
cut = int(0.9 * N) if N > 10 else N
train = ds.select(range(cut)).map(_add_labels, batched=True)
valid = ds.select(range(cut, N)).map(_add_labels, batched=True)

print(f"[dataset] blocks total={N} | train={len(train)} | valid={len(valid)} | block_size={BLOCK_SIZE}")

# quick probe: new tokens are single-piece now
probe = new_tokens[:10]
lens = {w: len(tok(w, add_special_tokens=False)["input_ids"]) for w in probe}
print("[probe] tokenization lengths (expect 1):", lens)


[tokenizer] candidates=654 | actually_added=654
[init] morpheme-composed vectors written for 654 tokens


Map:   0%|          | 0/1596 [00:00<?, ? examples/s]

Map:   0%|          | 0/178 [00:00<?, ? examples/s]

[dataset] blocks total=1774 | train=1596 | valid=178 | block_size=256
[probe] tokenization lengths (expect 1): {'conmanes': 1, 'presounder': 1, 'codayor': 1, 'outdanceist': 1, 'coraines': 1, 'consongen': 1, 'destarment': 1, 'revoicely': 1, 'overwatering': 1, 'miswomaning': 1}


pre train snap shot

In [None]:
def get_embedding_snapshot(words, model, tokenizer, name="snapshot"):
    model.eval()
    emb_matrix = model.get_input_embeddings().weight.data
    emb_norm = emb_matrix / emb_matrix.norm(dim=1, keepdim=True)

    snapshot = {"name": name, "vocab_size": len(tokenizer), "words": {}}

    for word_item in words[:50]:
        word = word_item['word']
        tid = tokenizer.convert_tokens_to_ids(word)
        if tid == tokenizer.unk_token_id:
            continue

        word_emb_norm = emb_norm[tid]
        sims = torch.matmul(word_emb_norm.unsqueeze(0), emb_norm.T)[0]
        top_k = torch.topk(sims, 11)

        neighbors = []
        for idx, sim in zip(top_k.indices[1:], top_k.values[1:]):
            neighbors.append({
                "token": tokenizer.convert_ids_to_tokens(idx.item()),
                "sim": round(sim.item(), 4)
            })

        snapshot["words"][word] = {
            "token_id": tid,
            "composition": f"{word_item['prefix'] or 'Ø'}+{word_item['root']}+{word_item['suffix'] or 'Ø'}",
            "embedding_norm": round(emb_matrix[tid].norm().item(), 4),
            "top_neighbors": neighbors[:10]
        }

    return snapshot

print("\n" + "="*60)
print("PRE-TRAINING SNAPSHOT")
print("="*60)

pre_snapshot = get_embedding_snapshot(synthetic_words, model, tok, "pre_morpheme")
with open(OUTDIR / "results" / "pre_morpheme_snapshot.json", "w") as f:
    json.dump(pre_snapshot, f, indent=2)

print(f"✓ Pre-training snapshot: {len(pre_snapshot['words'])} words")


PRE-TRAINING SNAPSHOT
✓ Pre-training snapshot: 49 words


training args

In [None]:
# === OOM-HARDENED TRAIN (FP32 activations, adamw_bnb_8bit optimizer) ===
from pathlib import Path
import os, random, json, torch, torch.nn.functional as F
from transformers import (
    TrainingArguments, Trainer, DataCollatorForLanguageModeling,
    EarlyStoppingCallback
)
import transformers

# 0) deps + CUDA alloc hygiene
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
try:
    import bitsandbytes as bnb  # noqa: F401
except Exception:
    # Colab-friendly install
    import subprocess, sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "bitsandbytes>=0.43.0"])
    import bitsandbytes as bnb  # noqa: F401

torch.cuda.empty_cache()

# 1) knobs (tune BATCH_SIZE/GRAD_ACCUM if you still see OOM)
OUTDIR        = Path(globals().get("OUTDIR", "runs/wake2vec_run"))
BATCH_SIZE    = int(globals().get("BATCH_SIZE", 8))   # try 4/2/1 if needed
GRAD_ACCUM    = int(globals().get("GRAD_ACCUM", 2))
EPOCHS_FULL   = int(globals().get("EPOCHS_FULL", 2))
LR_FULL       = float(globals().get("LR_FULL", 2e-5))
WARMUP_FRAC   = float(globals().get("WARMUP_FRAC", 0.03))
SAVE_STEPS    = int(globals().get("SAVE_STEPS", 200))
EVAL_ENABLED  = len(valid) > 0

# 2) reduce activation memory
if hasattr(model, "gradient_checkpointing_enable"):
    model.gradient_checkpointing_enable()
if hasattr(model.config, "use_cache"):
    model.config.use_cache = False  # avoid KV cache during training

# 3) Trainer args — 8-bit AdamW via bitsandbytes
args = TrainingArguments(
    output_dir=str(OUTDIR / "checkpoints"),
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=min(BATCH_SIZE, 4),
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LR_FULL,
    num_train_epochs=EPOCHS_FULL,
    warmup_ratio=WARMUP_FRAC,
    weight_decay=0.01,
    logging_steps=20,
    save_steps=SAVE_STEPS,
    save_strategy="steps",
    eval_strategy="steps" if EVAL_ENABLED else "no",
    eval_steps=SAVE_STEPS if EVAL_ENABLED else None,
    do_eval=EVAL_ENABLED,
    metric_for_best_model="loss",
    greater_is_better=False,
    load_best_model_at_end=True if EVAL_ENABLED else False,
    report_to=["none"],
    remove_unused_columns=False,
    fp16=False, bf16=False,                # keep FP32 activations (stable)
    optim="adamw_bnb_8bit",                # <<< 8-bit AdamW
)

collator  = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)
callbacks = [EarlyStoppingCallback(early_stopping_patience=2)] if EVAL_ENABLED else []

trainer = Trainer(
    model=model,
    args=args,
    data_collator=collator,
    train_dataset=train,
    eval_dataset=valid if EVAL_ENABLED else None,
    callbacks=callbacks,
)

print(f"[train] transformers={transformers.__version__} | optim=adamw_bnb_8bit | "
      f"bsz={BATCH_SIZE} x accum={GRAD_ACCUM} | eval={'on' if EVAL_ENABLED else 'off'}")
res = trainer.train()
print("[train] done. train_loss:", float(res.metrics.get("train_loss", float('nan'))))

# ---- quick neighbors for a few new tokens ----
with torch.no_grad():
    W = model.get_input_embeddings().weight.detach().float().cpu()
    vocab = tok.get_vocab()
    inv_vocab = {i: t for t, i in vocab.items()}

    def neighbors(term, k=6):
        ids = tok(term, add_special_tokens=False)["input_ids"]
        if len(ids) != 1: return []
        tid = ids[0]
        q = F.normalize(W[tid][None, :], dim=-1)
        sims = (q @ F.normalize(W.T, dim=0)).squeeze(0)
        vals, idxs = torch.topk(sims, k+1)
        out = []
        for v, i in zip(vals.tolist(), idxs.tolist()):
            t = inv_vocab.get(i, f"<{i}>")
            if t == term: continue
            out.append((t, round(float(v), 4)))
            if len(out) == k: break
        return out

sample_new = []
try:
    sample_new = random.sample(new_tokens, k=min(8, len(new_tokens)))
except Exception:
    if "synthetic_words" in globals():
        cand = [w["word"] for w in synthetic_words]
        sample_new = [w for w in cand[:100] if len(tok(w, add_special_tokens=False)["input_ids"]) == 1][:8]

print("\n[neighbors] sample new tokens:")
for s in sample_new:
    print(f"  {s:18s} -> {neighbors(s)}")

# ---- save final ----
final_dir = OUTDIR / "final"
final_dir.mkdir(parents=True, exist_ok=True)
model.save_pretrained(final_dir)
tok.save_pretrained(final_dir)

meta = {
    "epochs": EPOCHS_FULL, "learning_rate": LR_FULL, "warmup_ratio": WARMUP_FRAC,
    "batch_size": BATCH_SIZE, "grad_accum": GRAD_ACCUM,
    "block_size": int(globals().get("BLOCK_SIZE", 256)),
    "train_blocks": len(train), "valid_blocks": len(valid),
    "vocab_size": len(tok),
    "optimizer": "adamw_bnb_8bit",
}
(Path(final_dir / "run_meta.json")).write_text(json.dumps(meta, indent=2))
print("\n[save] wrote final model/tokenizer + metadata to:", final_dir)

[train] transformers=4.57.1 | optim=adamw_bnb_8bit | bsz=2 x accum=4 | eval=on


Step,Training Loss,Validation Loss
200,5.3989,5.425949
400,4.8278,5.230994


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].


[train] done. train_loss: 5.492043323516846

[neighbors] sample new tokens:
  requeenment        -> [('queenment', 0.9816), ('requeen', 0.9801), ('requeens', 0.9791), ('queenation', 0.9639), ('enqueen', 0.9606), ('conqueen', 0.9606)]
  pandayly           -> [('codayly', 0.9254), ('exdayly', 0.923), ('dayy', 0.9115), ('dayor', 0.9044), ('dayen', 0.9042), ('redays', 0.9039)]
  rainual            -> [('rainly', 0.9537), ('raines', 0.9533), ('birain', 0.9474), ('exrain', 0.9471), ('trirain', 0.9418), ('conrainer', 0.9399)]
  cothundered        -> [('cothunder', 0.9694), ('cothunderes', 0.9655), ('cothunderal', 0.9567), ('thundered', 0.9545), ('cothunderness', 0.943), ('rethunderly', 0.9213)]
  enwatery           -> [('cowatery', 0.9573), ('wateren', 0.9402), ('inwaterly', 0.9378), ('waterment', 0.9357), ('cowaters', 0.9346), ('rewateres', 0.9344)]
  exsonger           -> [('songer', 0.9731), ('exsongling', 0.9723), ('exsongity', 0.9667), ('exsongness', 0.9667), ('consonger', 0.9664), ('son

post training analysis

In [None]:
# === POST-TRAINING ANALYSIS CELL ===
import json, math, time, os
from pathlib import Path
import torch
import torch.nn.functional as F
import numpy as np

OUTDIR = Path(globals().get("OUTDIR", "runs/wake2vec_run"))
RESULTS_DIR = OUTDIR / "results"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# --- helper: get top-k neighbors for a single token id using CPU-normalized embeddings ---
def _topk_neighbors_for_id(tid, W_norm, inv_vocab, k=20):
    # W_norm: (V, D) float32 CPU tensor already normalized
    q = W_norm[tid : tid + 1]                 # (1, D)
    sims = (q @ W_norm.T).squeeze(0)          # (V,)
    # get top k+5 then filter out the same token etc.
    vals, idxs = torch.topk(sims, k + 5)
    out = []
    for v, i in zip(vals.tolist(), idxs.tolist()):
        token = inv_vocab.get(int(i), f"<{i}>")
        out.append({"token": token, "score": float(v)})
        if len(out) >= k:
            break
    return out

# --- main snapshot function ---
def get_embedding_snapshot(synth_words, model, tokenizer, tag="snapshot", topk=12):
    """
    synth_words: list of dicts with keys 'word','prefix','root','suffix' (as used previously)
    returns a dict with "meta" and "words" map
    """
    t0 = time.time()
    emb = model.get_input_embeddings().weight.detach().float().cpu()   # (V, D)
    V, D = emb.shape
    # normalize once
    W_norm = F.normalize(emb, dim=1)    # (V, D) normalized

    vocab = tokenizer.get_vocab()
    inv_vocab = {i: t for t, i in vocab.items()}

    # helper to get id(s) for text
    def ids_for_text(s):
        return tokenizer(s, add_special_tokens=False)["input_ids"]

    result = {
        "meta": {
            "tag": tag,
            "vocab_size": V,
            "emb_dim": D,
            "time": time.time(),
        },
        "words": {}
    }

    for item in (synth_words or []):
        w = item.get("word")
        if not w:
            continue
        ids = ids_for_text(w)
        # prefer words that are single-token, but if multi-token, average their token vectors
        if len(ids) == 0:
            continue
        vecs = emb[torch.tensor(ids, dtype=torch.long)]
        emb_mean = vecs.mean(dim=0)
        emb_norm = float(torch.norm(emb_mean).item())

        # pick top neighbors using the first token id if single-token; otherwise find nearest by projecting mean
        if len(ids) == 1:
            tid = ids[0]
            neighbors = _topk_neighbors_for_id(tid, W_norm, inv_vocab, k=topk)
        else:
            # fallback: compute cosine of mean vector vs W_norm
            q = F.normalize(emb_mean, dim=0).unsqueeze(0)    # (1, D)
            sims = (q @ W_norm.T).squeeze(0)
            vals, idxs = torch.topk(sims, topk + 5)
            neighbors = []
            for v, i in zip(vals.tolist(), idxs.tolist()):
                token = inv_vocab.get(int(i), f"<{i}>")
                neighbors.append({"token": token, "score": float(v)})
                if len(neighbors) >= topk:
                    break

        # Compose metadata
        composition = {
            "prefix": item.get("prefix"),
            "root": item.get("root"),
            "suffix": item.get("suffix"),
        }

        result["words"][w] = {
            "composition": composition,
            "embedding_norm": emb_norm,
            "tokenized_ids": ids,
            "top_neighbors": neighbors
        }

    result["meta"]["elapsed_sec"] = time.time() - t0
    return result

# --- run snapshot ---
print("\n" + "="*60)
print("POST-TRAINING ANALYSIS")
print("="*60)

# require synthetic_words (but can still run partial analysis if missing)
if "synthetic_words" not in globals():
    print("[warn] synthetic_words not found in globals(); nothing to snapshot.")
    synth = []
else:
    synth = synthetic_words

post_snapshot = get_embedding_snapshot(synth, model, tok, tag="post_morpheme")
out_path = RESULTS_DIR / "post_morpheme_snapshot.json"
out_path.write_text(json.dumps(post_snapshot, indent=2))
print(f"[saved] post snapshot -> {out_path}")

# --- if pre_snapshot exists, compare ---
if "pre_snapshot" in globals() and pre_snapshot and "words" in pre_snapshot:
    print("[compare] pre_snapshot detected — running comparison for sample words")
    comparison = {}
    # pick up to 50 words to compare (or fewer if pre has less)
    sample_words = list(pre_snapshot["words"].keys())[:50]
    for word in sample_words:
        if word not in post_snapshot["words"]:
            continue
        pre = pre_snapshot["words"][word]
        post = post_snapshot["words"][word]

        pre_neighbors = {n["token"] for n in pre.get("top_neighbors", [])[:5]}
        post_neighbors = {n["token"] for n in post.get("top_neighbors", [])[:5]}
        overlap = len(pre_neighbors & post_neighbors)

        comparison[word] = {
            "composition": pre.get("composition"),
            "norm_change": post.get("embedding_norm", 0.0) - pre.get("embedding_norm", 0.0),
            "neighbor_overlap": overlap,
            "pre_top5": [n["token"] for n in pre.get("top_neighbors", [])[:5]],
            "post_top5": [n["token"] for n in post.get("top_neighbors", [])[:5]]
        }

        # pretty print first 10 comparisons
        if len(comparison) <= 10:
            print(f"\n{word} ({comparison[word]['composition']}):")
            print(f"  Norm: {pre.get('embedding_norm',0):.4f} → {post.get('embedding_norm',0):.4f}")
            print(f"  Overlap: {overlap}/5")
            print(f"  Before: {', '.join(comparison[word]['pre_top5'])}")
            print(f"  After:  {', '.join(comparison[word]['post_top5'])}")

    comp_path = RESULTS_DIR / "morpheme_comparison.json"
    comp_path.write_text(json.dumps(comparison, indent=2))
    print(f"\n[written] comparison -> {comp_path}")

else:
    print("[note] pre_snapshot not found; saved only post_snapshot. If you have a pre_snapshot, load it into a var named `pre_snapshot` and re-run this cell to compare.")

# Save final model and tokenizer (tagged)
final_dir = OUTDIR / "final_morpheme_model"
final_dir.mkdir(parents=True, exist_ok=True)
model.save_pretrained(final_dir)
tok.save_pretrained(final_dir)

print(f"\nModel saved: {final_dir}")
print(f"Results folder: {RESULTS_DIR}")
print("\nMaybe it learned a thing or two — inspect the JSONs for details.")



POST-TRAINING ANALYSIS
[saved] post snapshot -> runs/morpheme_wake_20251030_0017/results/post_morpheme_snapshot.json
[compare] pre_snapshot detected — running comparison for sample words

conmanes (con-+man+-es):
  Norm: 0.1911 → 0.1972
  Overlap: 4/5
  Before: manes, conmaning, enmanes, comanes, mans
  After:  conmanes, manes, conmaning, enmanes, comanes

presounder (pre-+sound+-er):
  Norm: 0.2137 → 0.2209
  Overlap: 4/5
  Before: presounded, ensounder, resound, soundy, resoundly
  After:  presounder, presounded, ensounder, resound, soundy

soundity (Ø+sound+-ity):
  Norm: 0.1989 → 0.2063
  Overlap: 3/5
  Before: soundy, ▁sound, soundal, soundation, soundes
  After:  soundity, soundy, soundal, soundation, soundor

codayor (co-+day+-or):
  Norm: 0.1854 → 0.1925
  Overlap: 4/5
  Before: codayy, codayed, dayor, codayly, codayness
  After:  codayor, codayy, codayed, dayor, codayly

outdanceist (out-+dance+-ist):
  Norm: 0.2437 → 0.2499
  Overlap: 4/5
  Before: dancely, danceed, danceor,

In [31]:
# === Summary stats + top movers ===
import json, math, statistics
from pathlib import Path

OUTDIR = Path("runs/morpheme_wake_20251030_0017")
RESULTS = OUTDIR / "results"

pre = globals().get("pre_snapshot", None)
post = globals().get("post_snapshot", None)
# if not in-memory, try to load saved
if pre is None:
    ppath = RESULTS / "pre_morpheme_snapshot.json"
    if ppath.exists():
        pre = json.loads(ppath.read_text(encoding="utf-8"))
if post is None:
    ppath = RESULTS / "post_morpheme_snapshot.json"
    if ppath.exists():
        post = json.loads(ppath.read_text(encoding="utf-8"))

if not post:
    raise RuntimeError("post_snapshot missing (need post_morpheme_snapshot.json or post_snapshot in memory).")

words = list(post["words"].keys())
# build arrays of stats where pre exists
norm_changes = []
overlaps = []
records = {}
for w in words:
    post_w = post["words"][w]
    pre_w = (pre["words"].get(w) if pre and "words" in pre else None)
    if pre_w:
        pre_norm = pre_w.get("embedding_norm")
        post_norm = post_w.get("embedding_norm")
        if pre_norm is not None and post_norm is not None:
            nc = post_norm - pre_norm
            norm_changes.append(nc)
        # overlap top5
        pre_top5 = {n["token"] for n in pre_w.get("top_neighbors", [])[:5]}
        post_top5 = {n["token"] for n in post_w.get("top_neighbors", [])[:5]}
        overlap = len(pre_top5 & post_top5)
        overlaps.append(overlap)
        records[w] = {
            "pre_norm": pre_w.get("embedding_norm"),
            "post_norm": post_w.get("embedding_norm"),
            "norm_change": post_w.get("embedding_norm") - pre_w.get("embedding_norm"),
            "pre_top5": [n["token"] for n in pre_w.get("top_neighbors",[])[:5]],
            "post_top5": [n["token"] for n in post_w.get("top_neighbors",[])[:5]],
            "overlap_top5": overlap
        }

print("N compared words:", len(records))
if norm_changes:
    print("Norm change: mean=", statistics.mean(norm_changes), "median=", statistics.median(norm_changes),
          "min=", min(norm_changes), "max=", max(norm_changes))
if overlaps:
    print("Overlap top5: mean=", statistics.mean(overlaps), "median=", statistics.median(overlaps))

# Top movers by absolute norm change and by neighbor-change (1 - overlap)
top_by_norm = sorted(records.items(), key=lambda x: abs(x[1]['norm_change'] or 0), reverse=True)[:30]
top_by_overlap_loss = sorted(records.items(), key=lambda x: (5 - (x[1]['overlap_top5'] or 0)), reverse=True)[:30]

# Save
(Path(RESULTS / "summary_stats.json")).write_text(json.dumps({
    "n_compared": len(records),
    "norm_change_stats": {
        "mean": statistics.mean(norm_changes) if norm_changes else None,
        "median": statistics.median(norm_changes) if norm_changes else None,
    },
    "overlap_mean": statistics.mean(overlaps) if overlaps else None
}, indent=2))

print("\nTop 10 by absolute norm change:")
for w,v in top_by_norm[:10]:
    print(f" {w:20s} norm Δ={v['norm_change']:.4f} overlap={v['overlap_top5']}  pre_top5={v['pre_top5'][:5]}")

print("\nTop 10 by overlap loss (moved away):")
for w,v in top_by_overlap_loss[:10]:
    print(f" {w:20s} overlap={v['overlap_top5']} norm Δ={v['norm_change']:.4f}")


N compared words: 49
Norm change: mean= 0.005135101273595071 median= 0.006620781826972955 min= -0.07503086223602295 max= 0.01022875852584837
Overlap top5: mean= 3.7142857142857144 median= 4

Top 10 by absolute norm change:
 sound                norm Δ=-0.0750 overlap=2  pre_top5=['▁sound', 'soundation', 'soundity', 'resound', 'soundy']
 cloud                norm Δ=0.0102 overlap=4  pre_top5=['Cloud', 'clouded', '▁cloud', 'cloudly', 'clouder']
 probook              norm Δ=0.0099 overlap=4  pre_top5=['▁book', 'bookor', 'booker', 'bookes', 'booking']
 enstone              norm Δ=0.0090 overlap=4  pre_top5=['enstoneer', '▁stone', 'stonees', 'stoneed', 'stoneer']
 thundered            norm Δ=0.0089 overlap=4  pre_top5=['thunders', 'cothundered', 'dethunder', 'unthunder', 'cothunder']
 revoicely            norm Δ=0.0085 overlap=4  pre_top5=['revoice', 'voicely', 'revoiceing', 'revoiceal', 'voiceed']
 destar               norm Δ=0.0085 overlap=4  pre_top5=['destares', 'destaring', 'destarment

In [32]:
# === Plots: histograms + scatter ===
import matplotlib.pyplot as plt
import numpy as np

RESULTS.mkdir(parents=True, exist_ok=True)

# gather arrays
norms = [r["norm_change"] for r in records.values() if r["norm_change"] is not None]
ov = [r["overlap_top5"] for r in records.values() if r["overlap_top5"] is not None]

plt.figure()
plt.hist(ov, bins=range(0,6), align='left')
plt.title("Neighbor overlap (top5) distribution")
plt.xlabel("Overlap (0..5)")
plt.ylabel("Count")
plt.savefig(RESULTS / "hist_overlap_top5.png")
plt.close()

plt.figure()
plt.hist(norms, bins=50)
plt.title("Embedding norm change distribution (post - pre)")
plt.xlabel("Norm change")
plt.ylabel("Count")
plt.savefig(RESULTS / "hist_norm_change.png")
plt.close()

# scatter
x = norms
y = ov[:len(x)]
plt.figure(figsize=(6,4))
plt.scatter(x, y, alpha=0.6)
plt.title("Norm change vs overlap (top5)")
plt.xlabel("Norm change")
plt.ylabel("Overlap (top5)")
plt.tight_layout()
plt.savefig(RESULTS / "scatter_norm_vs_overlap.png")
plt.close()

print("Saved plots to", RESULTS)


Saved plots to runs/morpheme_wake_20251030_0017/results


In [33]:
# t-SNE / UMAP visualization
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
# optional: from umap import UMAP   # if you prefer UMAP and it's installed

# pick sample set (all new tokens in records)
tokens = list(records.keys())
# build vectors: use post embeddings from model
import torch
post_emb = model.get_input_embeddings().weight.detach().cpu()

def avg_vec_for_token(t):
    ids = tok(t, add_special_tokens=False)["input_ids"]
    if not ids:
        return None
    return post_emb[torch.tensor(ids)].mean(dim=0).numpy()

vecs = []
labels = []
for t in tokens:
    v = avg_vec_for_token(t)
    if v is None: continue
    vecs.append(v)
    labels.append(t)

# also compute centroids of pre_top5 neighbours (if available) using post-emb to compare positions
centroids = []
cent_labels = []
for t in tokens:
    pre_top5 = records[t].get("pre_top5", [])
    neigh_vecs = []
    for nt in pre_top5:
        v = avg_vec_for_token(nt)
        if v is not None:
            neigh_vecs.append(v)
    if neigh_vecs:
        centroids.append(np.mean(neigh_vecs, axis=0))
        cent_labels.append(t + "_precentroid")

# stack and run t-SNE
all_vecs = np.vstack(vecs + centroids) if centroids else np.vstack(vecs)
ts = TSNE(n_components=2, perplexity=30, init='random', random_state=42, n_iter=1000)
proj = ts.fit_transform(all_vecs)

n_tok = len(vecs)
proj_tokens = proj[:n_tok]
proj_cents = proj[n_tok:] if centroids else []

plt.figure(figsize=(10,10))
plt.scatter(proj_tokens[:,0], proj_tokens[:,1], s=25, c='C0', label='new_tokens')
if len(proj_cents):
    plt.scatter(proj_cents[:,0], proj_cents[:,1], s=60, marker='x', c='C1', label='pre_top5_centroids')
# annotate top extreme movers for quick inspection
for i, t in enumerate(labels):
    if abs(records[t]["norm_change"]) > np.percentile(np.abs(norms), 90) or records[t]["overlap_top5"] <= 2:
        plt.text(proj_tokens[i,0], proj_tokens[i,1], t, fontsize=8)
plt.legend()
plt.title("t-SNE: new tokens (post) and pre-top5 centroids")
plt.savefig(RESULTS / "tsne_newtokens_vs_precentroids.png", dpi=150)
plt.close()
print("Saved t-SNE plot:", RESULTS / "tsne_newtokens_vs_precentroids.png")




Saved t-SNE plot: runs/morpheme_wake_20251030_0017/results/tsne_newtokens_vs_precentroids.png


In [34]:
import math, torch
from transformers import Trainer, TrainingArguments

if 'valid' not in globals() or len(valid) == 0:
    print("No valid dataset found (valid variable missing or empty). Skipping perplexity.")
else:
    # make a lightweight evaluation loop using Trainer
    args_eval = TrainingArguments(output_dir=str(OUTDIR / "tmp_eval"), per_device_eval_batch_size=4, report_to=["none"])
    trainer_eval = Trainer(model=model, args=args_eval)
    res = trainer_eval.evaluate(eval_dataset=valid)
    eval_loss = res.get("eval_loss", res.get("loss", None))
    if eval_loss is not None:
        perp = math.exp(eval_loss)
        print(f"Validation loss: {eval_loss:.4f} -> Perplexity: {perp:.2f}")
    else:
        print("Could not find eval_loss in Trainer output:", res)

Validation loss: 5.2310 -> Perplexity: 186.98


In [35]:
# generation samples for selected tokens
import random, torch
def gen_seed(token, max_new=40):
    prompt = token + " "
    ids = tok(prompt, return_tensors="pt").to(next(model.parameters()).device)
    out = model.generate(**ids, max_new_tokens=max_new, do_sample=True, top_p=0.92, temperature=0.9)
    return tok.decode(out[0], skip_special_tokens=True)

# pick tokens
top_movers = [w for w,_ in top_by_norm[:10]]
low_movers = [w for w,_ in top_by_norm[-10:]]
sample_set = (top_movers[:5] + low_movers[:5])[:10]
print("Generating for sample tokens:", sample_set)
for t in sample_set:
    try:
        print("\n----", t)
        print(gen_seed(t, max_new=40))
    except Exception as e:
        print("Generation failed for", t, e)


Generating for sample tokens: ['sound', 'cloud', 'probook', 'enstone', 'thundered', 'danceed', 'cowordes', 'codayor', 'consongen', 'rewalkum']

---- sound
sound . Douth! It's as it is, the wold. Tip. Oclinate.

[1] A covenity of ballybills and b

---- cloud
cloud   O'morn
is the hodle! Ogone!

  Mick

  Jute.— A jute of pantry is a jacper.


---- probook
probook  2683), the pigeon's pudding for an man's
                man; he was a young man's man (I am a rede) and

---- enstone
enstone   Higd.

[1] My huntswoman, my boys, my hubbings, my tilt on my tinkle.

[2

---- thundered
thundered   hunt!
dill and ove. And all the fungos, the lamus, the flamim, the
chillons, the lordies, the

---- danceed
danceed 3822. Tailu!

—Peggiery, sir, he was aloft as I looked to it, as I thought,
when he is the

---- cowordes
cowordes 31. Netty Bawls, the binnies, we to'll say. The old
telltown, the first of the jews, the one of

---- codayor
codayor              Ondt?
tip and hug! Tip. From the tumbum's

In [None]:
# sanity
import torch
emb = model.get_input_embeddings().weight.detach()
print("Embedding shape:", emb.shape, "mean:", float(emb.mean()), "std:", float(emb.std()))
if hasattr(model, "lm_head"):
    tied = (model.lm_head.weight.data_ptr() == emb.data_ptr())
    print("lm_head tied to embeddings?", tied)
else:
    print("No lm_head attribute on model object.")
# token frequency check: how often new tokens occur in dataset (rough estimate)
def token_freq_estimate(token, dataset, tokenizer):
    # sample some blocks and count occurrences
    import random
    sample_idxs = random.sample(range(len(dataset)), min(200, len(dataset)))
    cnt = 0
    for i in sample_idxs:
        toks = dataset[i]["input_ids"]
        # approximate: decode block and count token string occurrences
        s = tokenizer.decode(toks, skip_special_tokens=True)
        cnt += s.count(token)
    return cnt / len(sample_idxs)
print("Estimated freq of a few tokens in train (sampled blocks):")
for t in sample_set[:6]:
    print(" ", t, token_freq_estimate(t, train, tok) if 'train' in globals() else "no train ds")
