<a href="https://colab.research.google.com/github/mahb97/Wake2vec/blob/main/token_injection_%26_training_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wake2vec Token Injection & Training


In [49]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


setup

In [50]:
# If Colab complains about versions, uncomment and run once.
%pip -q install --upgrade transformers datasets accelerate sentencepiece

In [51]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
os.environ["PYTHONHASHSEED"] = "1337"

import torch
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.use_deterministic_algorithms(True)

import random, json, math, re
import numpy as np
from datetime import datetime
from pathlib import Path
from collections import defaultdict

SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print("gon fry your embedding layer")

gon fry your embedding layer


config

In [53]:
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
CORPUS_PATH = "/content/drive/MyDrive/Wake2vec_runs/fw.txt"
WAKE_LEXICON = "/content/drive/MyDrive/Wake2vec_runs/wake_lexicon.txt"

BATCH_SIZE = 2
BLOCK_SIZE = 256
EPOCHS = 2
LR = 2e-5
WARMUP_RATIO = 0.05
WEIGHT_DECAY = 0.01
GRAD_ACCUM = 4
SAVE_STEPS = 200
MAX_NEW_TERMS = 128
AUGMENT_EACH = 16  # More synthetic data

# PARAMS
PORTMANTEAU_ALPHA = 0.5      # Blend strength for portmanteaus
ROTATION_STRENGTH = 0.01     # Embedding drift
MULTILINGUAL_STRENGTH = 0.1  # Cluster tightness
NOISE_AMPLITUDE = 0.05       # Wake static
PHONETIC_WEIGHT = 1.5        # Alliteration bonus
EMBEDDING_TEMP_MAX = 2.0     # Annealing max temp

# Output
RUN_ID = datetime.now().strftime("wake2vec_chaos_%Y%m%d_%H%M")
OUTDIR = Path(f"./runs/{RUN_ID}")
(OUTDIR / "results").mkdir(parents=True, exist_ok=True)
(OUTDIR / "checkpoints").mkdir(parents=True, exist_ok=True)

print(f"Run ID: {RUN_ID}")

Run ID: wake2vec_chaos_20251029_2248


load corpus

In [55]:
def load_corpus(path):
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Corpus not found: {p}")
    text = p.read_text(encoding="utf-8", errors="ignore")
    print(f"Loaded corpus: {len(text)} chars")
    return text

FW_TEXT = load_corpus(CORPUS_PATH)

Loaded corpus: 1364712 chars


load model and tok

In [56]:
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

set_seed(SEED)

tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model.to(DEVICE)

print(f"Model: {MODEL_NAME}")
print(f"Device: {DEVICE}")
print(f"Initial vocab: {len(tok)}")

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
Device: cuda
Initial vocab: 32000


load lex

In [57]:
def load_lexicon(path):
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Lexicon not found: {p}")

    terms = []
    with p.open("r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            s = line.strip()
            if s:
                terms.append(s)

    seen = set()
    unique_terms = []
    for t in terms:
        if t not in seen:
            seen.add(t)
            unique_terms.append(t)

    print(f"Loaded lexicon: {len(unique_terms)} terms")
    return unique_terms

all_terms = load_lexicon(WAKE_LEXICON)

def is_multitoken(term, tokenizer):
    ids = tokenizer(term, add_special_tokens=False)["input_ids"]
    return len(ids) > 1

NEW_TERMS = [t for t in all_terms if is_multitoken(t, tok)][:MAX_NEW_TERMS]
print(f"elected {len(NEW_TERMS)} multi-token terms")

Loaded lexicon: 44990 terms
elected 128 multi-token terms


portmanteaus & setup

In [59]:
def detect_portmanteaus(terms, vocab, tokenizer):
    """Find likely portmanteaus (compound words in Wake)"""
    portmanteaus = []
    vocab_set = set(vocab.keys())

    # Known Joyce portmanteaus to bootstrap
    known_pairs = [
        ("river", "run", "riverrun"),
        ("chaos", "cosmos", "chaosmos"),
        ("thunder", "words", "thunderwords"),
        ("Anna", "Livia", "AnnaLivia"),
        ("Phoenix", "Park", "PhoenixPark"),
    ]

    for w1, w2, port in known_pairs:
        # Check if portmanteau is in our terms
        for term in terms:
            if port.lower() in term.lower():
                # Check if component words exist in vocab
                id1 = tokenizer.convert_tokens_to_ids(w1)
                id2 = tokenizer.convert_tokens_to_ids(w2)
                if id1 != tokenizer.unk_token_id and id2 != tokenizer.unk_token_id:
                    portmanteaus.append((w1, w2, term))

    # Try detecting from terms themselves
    for term in terms[:100]:
        term_clean = term.strip("▁").lower()

        # Try common English words as components
        common_words = ["river", "run", "water", "fire", "time", "word", "wake",
                       "night", "day", "thunder", "light", "dark", "man", "woman",
                       "sun", "moon", "star", "earth", "sea", "sky"]

        for word in common_words:
            if word in term_clean and len(word) < len(term_clean):
                # Check if word is in vocab
                wid = tokenizer.convert_tokens_to_ids(word)
                if wid != tokenizer.unk_token_id:
                    # Try to find another component
                    idx = term_clean.find(word)
                    if idx == 0:
                        # word is at start, check suffix
                        suffix = term_clean[len(word):]
                        if len(suffix) >= 3:
                            # Try suffix as word
                            sid = tokenizer.convert_tokens_to_ids(suffix)
                            if sid != tokenizer.unk_token_id:
                                portmanteaus.append((word, suffix, term))
                                break
                    elif idx > 0:
                        # word is in middle/end, check prefix
                        prefix = term_clean[:idx]
                        if len(prefix) >= 3:
                            pid = tokenizer.convert_tokens_to_ids(prefix)
                            if pid != tokenizer.unk_token_id:
                                portmanteaus.append((prefix, word, term))
                                break

    # Dedupe
    seen = set()
    unique_ports = []
    for p in portmanteaus:
        key = (p[0], p[1], p[2])
        if key not in seen:
            seen.add(key)
            unique_ports.append(p)

    # If still empty, create synthetic portmanteaus
    if len(unique_ports) == 0:
        print("⚠ No portmanteaus detected, creating synthetic blends...")
        synthetic_pairs = [
            ("thunder", "word", "thunderword"),
            ("river", "flow", "riverflow"),
            ("night", "fall", "nightfall"),
            ("sun", "rise", "sunrise"),
            ("moon", "light", "moonlight"),
        ]

        for w1, w2, blend in synthetic_pairs:
            id1 = tokenizer.convert_tokens_to_ids(w1)
            id2 = tokenizer.convert_tokens_to_ids(w2)
            if id1 != tokenizer.unk_token_id and id2 != tokenizer.unk_token_id:
                unique_ports.append((w1, w2, blend))

    print(f"Detected/created {len(unique_ports)} portmanteaus")
    return unique_ports

vocab = tok.get_vocab()
PORTMANTEAUS = detect_portmanteaus(NEW_TERMS, vocab, tok)
print(f"Examples: {PORTMANTEAUS[:5]}")

⚠ No portmanteaus detected, creating synthetic blends...
Detected/created 2 portmanteaus
Examples: [('river', 'flow', 'riverflow'), ('night', 'fall', 'nightfall')]


basic multi-ling

In [60]:
MULTILINGUAL_CLUSTERS = {
    "river": ["river", "rivière", "Fluss", "fiume", "río"],
    "thunder": ["thunder", "tonnerre", "Donner", "tuono", "trueno"],
    "night": ["night", "nuit", "Nacht", "notte", "noche"],
    "wake": ["wake", "éveil", "Wache", "veglia", "despertar"],
    "word": ["word", "mot", "Wort", "parola", "palabra"],
    "time": ["time", "temps", "Zeit", "tempo", "tiempo"],
    "water": ["water", "eau", "Wasser", "acqua", "agua"],
    "fire": ["fire", "feu", "Feuer", "fuoco", "fuego"],}

token injection

In [61]:
def inject_tokens_chaotic(terms, model, tokenizer, portmanteaus):
    """Inject with portmanteau-aware initialization"""
    to_add = []
    for t in terms:
        variants = [t, f"▁{t}"] if not t.startswith("▁") else [t]
        for form in variants:
            tid = tokenizer.convert_tokens_to_ids(form)
            if tid == tokenizer.unk_token_id:
                to_add.append(form)

    if not to_add:
        print("No new tokens to add")
        return []

    old_size = model.get_input_embeddings().num_embeddings
    n_added = tokenizer.add_tokens(to_add, special_tokens=False)
    model.resize_token_embeddings(len(tokenizer))

    emb = model.get_input_embeddings().weight.data
    existing_std = emb[:old_size].std().item()

    # Initialize with random normal + noise
    new_emb = torch.randn(n_added, emb.shape[1], device=emb.device) * existing_std
    noise = torch.randn_like(new_emb) * (existing_std * 0.1)
    emb[old_size:old_size + n_added] = new_emb + noise

    # PORTMANTEAU LERPING
    print("Applying portmanteau blending...")
    for w1, w2, port in portmanteaus:
        id1 = tokenizer.convert_tokens_to_ids(w1)
        id2 = tokenizer.convert_tokens_to_ids(w2)
        port_id = tokenizer.convert_tokens_to_ids(port)

        if (port_id != tokenizer.unk_token_id and
            id1 != tokenizer.unk_token_id and
            id2 != tokenizer.unk_token_id and
            port_id >= old_size):  # Only for newly added tokens

            alpha = PORTMANTEAU_ALPHA
            emb[port_id] = alpha * emb[id1] + (1-alpha) * emb[id2]
            # Add chaos noise
            emb[port_id] += torch.randn_like(emb[port_id]) * existing_std * 0.05

    if hasattr(model, "tie_weights"):
        model.tie_weights()

    print(f"Injected {n_added} tokens (vocab: {old_size} → {len(tokenizer)})")
    return to_add

added = inject_tokens_chaotic(NEW_TERMS, model, tok, PORTMANTEAUS)

The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Applying portmanteau blending...
Injected 256 tokens (vocab: 32000 → 32256)


data aug

In [None]:
def generate_wake_mashup(terms, n_samples=500):
    """Generate maximum chaos synthetic Wake text"""
    patterns = [
        "The {t1}ing {t2} of {t3}",
        "By {t1} and by {t2} and by {t3}",
        "In the {t1} of his {t2} was {t3}",
        "From {t1} to {t2}, a {t3} of words",
        "… and {t1} rolled back to {t2} in a {t3}.",
        "He spoke of {t1} as if {t2} kept its {t3}.",
        "Between {t1} and {t2}, the {t3} broke in two.",
        "Call it {t1}, call it {t2} and {t3}.",
        "Through {t1} the {t2} wound their way to {t3}.",
        "{t1} {t2} {t3} again in the chaosmos.",
        "Riverrun past {t1} and {t2} from {t3}.",
        "The {t1} of {t2} brings us back to {t3}.",
        "{t1}! {t2}! O {t3}!",
        "In the beginning was the {t1}, the {t2}, the {t3}.",
        "{t1} {t2} {t3} thunderwords.",
    ]

    # Add recursive puns
    pun_patterns = [
        "{t1} the {t1}er {t1}ing",
        "{t1} un{t1} re{t1}",
        "the {t1} of {t1}s",
    ]
    patterns.extend(pun_patterns)

    examples = []
    for _ in range(n_samples):
        if random.random() < 0.3:
            # Single term recursion
            t = random.choice(terms)
            pattern = random.choice(pun_patterns)
            examples.append(pattern.format(t1=t))
        else:
            # Multi-term mashup
            shuffled = random.sample(terms, min(3, len(terms)))
            pattern = random.choice(patterns)
            examples.append(pattern.format(
                t1=shuffled[0],
                t2=shuffled[1] if len(shuffled) > 1 else shuffled[0],
                t3=shuffled[2] if len(shuffled) > 2 else shuffled[0]
            ))

    random.shuffle(examples)
    return examples

def create_anagrams(terms, n_per_term=2):
    """Create anagram-based synthetic examples"""
    examples = []
    for term in terms[:30]:
        chars = list(term)
        for _ in range(n_per_term):
            random.shuffle(chars)
            anagram = ''.join(chars)
            examples.append(f"The {term} became {anagram} in the telling.")
    return examples

# Generate synthetic chaos
synthetic = generate_wake_mashup(NEW_TERMS, n_samples=len(NEW_TERMS) * AUGMENT_EACH)
anagrams = create_anagrams(NEW_TERMS, n_per_term=2)
all_synthetic = synthetic + anagrams
random.shuffle(all_synthetic)

COMBINED_TEXT = FW_TEXT + "\n" + "\n".join(all_synthetic)

print(f"✓ Generated {len(all_synthetic)} synthetic chaos examples")
print(f"✓ Combined corpus: {len(COMBINED_TEXT)} chars")

# Sample
print(f"\nSample chaos:")
for s in all_synthetic[:5]:
    print(f"  {s}")

dataset prep

In [1]:
from datasets import Dataset

def create_blocks(text, tokenizer, block_size):
    ids = tokenizer(text, add_special_tokens=False, return_attention_mask=False)["input_ids"]
    n_blocks = len(ids) // block_size
    if n_blocks == 0:
        raise ValueError(f"Text too short for block_size={block_size}")
    ids = ids[:n_blocks * block_size]
    arr = np.array(ids, dtype=np.int32).reshape(n_blocks, block_size)
    return Dataset.from_dict({"input_ids": arr.tolist()})

ds = create_blocks(COMBINED_TEXT, tok, BLOCK_SIZE)
print(f"✓ Created {len(ds)} blocks")

ds = ds.map(lambda x: {"labels": x["input_ids"]}, batched=True)

n = len(ds)
if n > 20:
    split_idx = int(n * 0.9)
    train_ds = ds.select(range(split_idx))
    valid = ds.select(range(split_idx, n))
else:
    train_ds = ds
    valid = None

print(f"Train: {len(train_ds)}, Val: {len(valid) if valid else 0}")

KeyboardInterrupt: 

extra mods

In [None]:
class WakeNoiseInjector:
    """Inject structured chaos into embeddings"""
    def __init__(self, amplitude=NOISE_AMPLITUDE, schedule="cyclical"):
        self.amplitude = amplitude
        self.schedule = schedule
        self.step = 0

    def __call__(self, embeddings):
        self.step += 1
        if self.schedule == "cyclical":
            noise_scale = self.amplitude * abs(math.sin(self.step / 100))
        else:
            noise_scale = self.amplitude

        noise = torch.randn_like(embeddings) * noise_scale
        return embeddings + noise

class RotationLoss(torch.nn.Module):
    """Encourage embeddings to drift through semantic space"""
    def __init__(self, strength=ROTATION_STRENGTH):
        super().__init__()
        self.strength = strength
        self.rotation_matrix = None

    def forward(self, embeddings, term_ids):
        if len(term_ids) == 0:
            return torch.tensor(0.0, device=embeddings.device)

        wake_embs = embeddings[term_ids]

        # Create/reuse rotation matrix
        if self.rotation_matrix is None or self.rotation_matrix.shape[0] != embeddings.shape[1]:
            rot = torch.randn(embeddings.shape[1], embeddings.shape[1], device=embeddings.device)
            self.rotation_matrix = torch.linalg.qr(rot)[0]

        rotated = wake_embs @ self.rotation_matrix
        # Encourage small rotation away from current position
        loss = -self.strength * torch.mean((wake_embs - rotated) ** 2)
        return loss

class MultilingualClusteringLoss(torch.nn.Module):
    """Force multilingual cognates to cluster"""
    def __init__(self, clusters, tokenizer, strength=MULTILINGUAL_STRENGTH):
        super().__init__()
        self.strength = strength
        self.cluster_ids = []

        for cluster_words in clusters.values():
            ids = []
            for word in cluster_words:
                tid = tokenizer.convert_tokens_to_ids(word)
                if tid != tokenizer.unk_token_id:
                    ids.append(tid)
            if len(ids) > 1:
                self.cluster_ids.append(ids)

        print(f"✓ Multilingual clusters: {len(self.cluster_ids)}")

    def forward(self, embeddings):
        if len(self.cluster_ids) == 0:
            return torch.tensor(0.0, device=embeddings.device)

        total_loss = 0
        for ids in self.cluster_ids:
            cluster_embs = embeddings[ids]
            centroid = cluster_embs.mean(dim=0, keepdim=True)
            dist = torch.sum((cluster_embs - centroid) ** 2)
            total_loss += dist

        return self.strength * total_loss / len(self.cluster_ids)

class PhoneticWeighting:
    """Weight loss by phonetic similarity (alliteration)"""
    def __init__(self, weight=PHONETIC_WEIGHT):
        self.weight = weight

    def phonetic_similarity(self, w1, w2):
        """Rough phonetic similarity via first 2 chars"""
        if len(w1) < 2 or len(w2) < 2:
            return 0
        return int(w1[:2].lower() == w2[:2].lower())

    def get_weight(self, token_ids, tokenizer):
        """Compute phonetic weight for a sequence"""
        tokens = [tokenizer.convert_ids_to_tokens(int(tid)) for tid in token_ids]
        weights = torch.ones(len(tokens))

        for i in range(len(tokens) - 1):
            if self.phonetic_similarity(tokens[i], tokens[i+1]):
                weights[i] *= self.weight
                weights[i+1] *= self.weight

        return weights

extra

In [None]:
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
import inspect

# Get IDs of injected terms
injected_ids = []
for term in NEW_TERMS:
    tid = tok.convert_tokens_to_ids(term)
    if tid != tok.unk_token_id:
        injected_ids.append(tid)

injected_ids = torch.tensor(injected_ids, device=DEVICE)

# Initialize chaos modules
noise_injector = WakeNoiseInjector()
rotation_loss = RotationLoss().to(DEVICE)
multilingual_loss = MultilingualClusteringLoss(MULTILINGUAL_CLUSTERS, tok).to(DEVICE)
phonetic_weighter = PhoneticWeighting()

class ChaosTrainer(Trainer):
    """Trainer with maximum Joycean chaos"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.chaos_step = 0

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Get base loss
        outputs = model(**inputs)
        base_loss = outputs.loss

        # Get embeddings
        emb = model.get_input_embeddings().weight

        # Apply chaos losses
        chaos_loss = 0

        # 1. Rotation loss (embedding drift)
        if len(injected_ids) > 0:
            rot_loss = rotation_loss(emb, injected_ids)
            chaos_loss += rot_loss

        # 2. Multilingual clustering
        multi_loss = multilingual_loss(emb)
        chaos_loss += multi_loss

        # 3. Embedding noise (apply during forward)
        # (This modifies embeddings in-place for this step only)

        # Total loss
        total_loss = base_loss + chaos_loss

        self.chaos_step += 1

        # Log extra metrics periodically
        if self.chaos_step % 50 == 0:
            self.log({
                "chaos/rotation": rot_loss.item() if len(injected_ids) > 0 else 0,
                "chaos/multilingual": multi_loss.item(),
                "chaos/total": chaos_loss.item(),
            })

        return (total_loss, outputs) if return_outputs else total_loss

training set up

In [2]:
collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)
has_eval = valid is not None and len(valid) > 0

# Check supported args
sig = inspect.signature(TrainingArguments.__init__)
supported = sig.parameters.keys()

args_dict = {
    "output_dir": str(OUTDIR / "checkpoints"),
    "num_train_epochs": EPOCHS,
    "per_device_train_batch_size": BATCH_SIZE,
    "gradient_accumulation_steps": GRAD_ACCUM,
    "learning_rate": LR,
    "weight_decay": WEIGHT_DECAY,
    "warmup_ratio": WARMUP_RATIO,
    "logging_steps": 20,
    "fp16": False,
    "seed": SEED,
}

if "save_strategy" in supported:
    args_dict.update({"save_strategy": "steps", "save_steps": SAVE_STEPS, "save_total_limit": 2})
else:
    args_dict["save_steps"] = SAVE_STEPS

if "evaluation_strategy" in supported:
    args_dict["evaluation_strategy"] = "steps" if has_eval else "no"
    if has_eval: args_dict["eval_steps"] = SAVE_STEPS
elif "evaluate_during_training" in supported:
    args_dict["evaluate_during_training"] = has_eval
    if has_eval and "eval_steps" in supported: args_dict["eval_steps"] = SAVE_STEPS

if "report_to" in supported: args_dict["report_to"] = ["none"]
if "bf16" in supported: args_dict["bf16"] = False
if "remove_unused_columns" in supported: args_dict["remove_unused_columns"] = False
if "lr_scheduler_type" in supported: args_dict["lr_scheduler_type"] = "cosine"
if "per_device_eval_batch_size" in supported: args_dict["per_device_eval_batch_size"] = BATCH_SIZE

training_args = TrainingArguments(**args_dict)

trainer = ChaosTrainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=train_ds,
    eval_dataset=valid if has_eval else None,
)

print(f"Chaos Trainer ready")

NameError: name 'DataCollatorForLanguageModeling' is not defined

pre train snap-shot

In [47]:
print("\n" + "="*60)
print("TAKING PRE-TRAINING SNAPSHOT")
print("="*60 + "\n")

def get_embedding_snapshot(terms, model, tokenizer, name="snapshot"):
    """Capture embedding state for comparison"""
    model.eval()
    emb_matrix = model.get_input_embeddings().weight.data
    emb_norm = emb_matrix / emb_matrix.norm(dim=1, keepdim=True)

    snapshot = {
        "name": name,
        "vocab_size": len(tokenizer),
        "embedding_dim": emb_matrix.shape[1],
        "terms": {}
    }

    for term in terms:
        tid = tokenizer.convert_tokens_to_ids(term)
        if tid == tokenizer.unk_token_id:
            continue

        # Get embedding vector (as list for JSON serialization)
        term_emb_norm = emb_norm[tid]

        # Compute similarities with all tokens
        sims = torch.matmul(term_emb_norm.unsqueeze(0), emb_norm.T)[0]

        # Get top 10 neighbors
        top_k = torch.topk(sims, 11)  # 11 to exclude self

        neighbors = []
        for idx, sim in zip(top_k.indices[1:], top_k.values[1:]):  # Skip first (self)
            neighbor_token = tokenizer.convert_ids_to_tokens(idx.item())
            neighbors.append({
                "token": neighbor_token,
                "sim": round(sim.item(), 4)
            })

        # Store term info
        snapshot["terms"][term] = {
            "token_id": tid,
            "embedding_norm": round(emb_matrix[tid].norm().item(), 4),
            "top_neighbors": neighbors[:10]
        }

    return snapshot

# Take snapshot of injected terms
pre_snapshot = get_embedding_snapshot(NEW_TERMS[:50], model, tok, name="pre_training")

# Save snapshot
snapshot_path = OUTDIR / "results" / "pre_training_snapshot.json"
with open(snapshot_path, "w") as f:
    json.dump(pre_snapshot, f, indent=2)

print(f"Pre-training snapshot saved: {snapshot_path}")
print(f"Captured {len(pre_snapshot['terms'])} terms")

# Display sample
print("\nSample pre-training neighbors:")
for term, data in list(pre_snapshot["terms"].items())[:3]:
    print(f"\n{term} (norm: {data['embedding_norm']}):")
    for n in data["top_neighbors"][:5]:
        print(f"  {n['token']}: {n['sim']}")

print("\n" + "="*60)


TAKING PRE-TRAINING SNAPSHOT

Pre-training snapshot saved: runs/wake2vec_20251029_2226/results/pre_training_snapshot.json
Captured 50 terms

Sample pre-training neighbors:

paùpulation (norm: 0.0502):
  ĠWash: 1.0
  Ġaccomplished: 1.0
  paùpulation: 1.0
  Ġanchored: 1.0
  ×ķ: 1.0

générations (norm: 0.0116):
  Ġtal: 1.0
  générations: 1.0
  ĠProto: 1.0
  ĠSl: 1.0
  021: 1.0

introdùce (norm: 0.0469):
  ĠKand: 1.0
  introdùce: 1.0
  ĠDATA: 1.0
  Ġintro: 1.0
  Ġadolescence: 1.0



train

In [None]:
print("\n" + "="*60)
print("UNLEASHING CHAOS")
print("="*60 + "\n")

result = trainer.train()

print("\n" + "="*60)
print("CHAOS COMPLETE")
print("="*60)
print(f"Final loss: {result.metrics.get('train_loss', 'N/A'):.4f}")

analysis

In [None]:
print("\n" + "="*60)
print("POST-CHAOS ANALYSIS")
print("="*60)

post_snapshot = get_embedding_snapshot(NEW_TERMS[:50], model, tok, "post_chaos")
post_path = OUTDIR / "results" / "post_chaos_snapshot.json"
with open(post_path, "w") as f:
    json.dump(post_snapshot, f, indent=2)

# Compare
comparison = {}
for term in list(pre_snapshot["terms"].keys())[:10]:
    if term not in post_snapshot["terms"]:
        continue

    pre = pre_snapshot["terms"][term]
    post = post_snapshot["terms"][term]

    pre_neighbors = {n["token"] for n in pre["top_neighbors"][:5]}
    post_neighbors = {n["token"] for n in post["top_neighbors"][:5]}
    overlap = len(pre_neighbors & post_neighbors)

    comparison[term] = {
        "norm_change": post["embedding_norm"] - pre["embedding_norm"],
        "neighbor_overlap": overlap,
        "pre_top5": [n["token"] for n in pre["top_neighbors"][:5]],
        "post_top5": [n["token"] for n in post["top_neighbors"][:5]]
    }

    print(f"\n{term}:")
    print(f"  Norm: {pre['embedding_norm']:.4f} → {post['embedding_norm']:.4f}")
    print(f"  Overlap: {overlap}/5")
    print(f"  Before: {', '.join(comparison[term]['pre_top5'])}")
    print(f"  After:  {', '.join(comparison[term]['post_top5'])}")

comparison_path = OUTDIR / "results" / "chaos_comparison.json"
with open(comparison_path, "w") as f:
    json.dump(comparison, f, indent=2)

# Save final model
final_dir = OUTDIR / "final_chaos_model"
model.save_pretrained(final_dir)
tok.save_pretrained(final_dir)

print(f"\nChaos model saved: {final_dir}")
print(f"Results: {OUTDIR / 'results'}")