<a href="https://colab.research.google.com/github/mahb97/Wake2vec/blob/main/Wake2Vec_750_1300.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wake2Vec: Lexicon-Augmented Embedding Training (Steps 750-1300)

Continuation training for embedding-only fine-tuning of TinyLlama-1.1B with Finnegans Wake lexicon injection.

---

## Overview

This notebook implements a continuation phase of embedding-only fine-tuning for a large language model augmented with approximately 44,000 Wake-specific tokens. The training methodology isolates new token embeddings through gradient masking while keeping pre-trained model parameters frozen, enabling lexical integration into the existing semantic space without catastrophic forgetting.

## Training Configuration

Phase: Continuation from checkpoint 750 to step 1300

Model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T

Training strategy: Embedding-only optimization with gradient masking on base vocabulary

Data: Finnegans Wake corpus with held-out validation set

Optimization: Adafactor (learning rate 5e-4, no warmup)

Hardware: Single T4 GPU (15GB VRAM) via Google Colab

## Notebook Structure

The core workflow runs as six sequential stages; some stages span more than one cell, and the notebook also contains memory-cleanup and Phase 1 evaluation cells:

1. Path configuration and environment setup
2. Dependency pinning and compatibility patches
3. Model and tokenizer loading with gradient masking
4. Dataset loading and sequence truncation
5. Training callback definitions (monitoring, backups, snapshots)
6. Training execution and resumption logic

## Monitoring and Backup Utilities

The notebook provides automated systems for long-running training stability:

- Loss and evaluation metric tracking via structured JSON logs
- Checkpoint inventory and validation (weights, optimizer state, trainer state)
- Sentry mirror system for automated checkpoint backups to Google Drive
- Embedding snapshot capture at 50-step intervals
- Training throughput diagnostics (average seconds-per-step reporting)
- Heartbeat metadata for remote monitoring

## Implementation Details

The training resumes from checkpoint-750 and proceeds with the following hyperparameters:

- Batch size: 1 (effective batch size 16 via gradient accumulation)
- Gradient accumulation steps: 16
- Learning rate: 5e-4 (no warmup)
- Sequence length: 384 tokens (runtime truncation)
- Save frequency: 75 steps
- Evaluation frequency: 200 steps
- Checkpoint retention: 3 most recent (limits disk usage)
- Gradient clipping: maximum norm 1.0
- Gradient checkpointing: enabled for memory efficiency

All training callbacks (evaluation triggers, sentry mirroring, embedding snapshots, throughput monitoring) are integrated into the Hugging Face Trainer workflow for automated execution during training. The notebook includes compatibility patches for transformers 4.57.1 and accelerate 1.2.1 to ensure stable execution on Colab environments.

Archive ⬇

In [None]:
import os
from pathlib import Path
from google.colab import drive

# Mount Google Drive (an existing mount is kept as-is).
drive.mount('/content/drive', force_remount=False)

# --- Run identity and filesystem layout -----------------------------------
RUN_ID = "t4_1762376560"
DRIVE = Path("/content/drive/MyDrive/wake2vec")
DATASETS = Path("/content/datasets")
LOCAL_RUN = Path("/content/runs").joinpath(RUN_ID)                      # local Trainer output dir
SENTRY = DRIVE.joinpath("sentry_backups", RUN_ID)                       # Drive backup dir for checkpoints
RESUME_FROM = DRIVE.joinpath("runs", RUN_ID, "checkpoint-750-rebuilt")  # checkpoint to resume from

LOCAL_RUN.mkdir(parents=True, exist_ok=True)
SENTRY.mkdir(parents=True, exist_ok=True)

# --- Step budget: resume at LAST_STEP, run at most 550 more steps,
# never past FINAL_TARGET.
FINAL_TARGET = 1300
LAST_STEP = 750
TARGET = min(FINAL_TARGET, LAST_STEP + 550)

# Only takes effect if nothing has set the allocator config already.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

print("\n".join([
    "[PATH CONFIGURATION]",
    f"  RUN_ID: {RUN_ID}",
    f"  RESUME_FROM: {RESUME_FROM}",
    f"  LOCAL_RUN: {LOCAL_RUN}",
    f"  SENTRY: {SENTRY}",
    f"  DATASETS: {DATASETS}",
    f"  TARGET: {LAST_STEP} → {TARGET}",
]))

Mounted at /content/drive
[PATH CONFIGURATION]
  RUN_ID: t4_1762376560
  RESUME_FROM: /content/drive/MyDrive/wake2vec/runs/t4_1762376560/checkpoint-750-rebuilt
  LOCAL_RUN: /content/runs/t4_1762376560
  SENTRY: /content/drive/MyDrive/wake2vec/sentry_backups/t4_1762376560
  DATASETS: /content/datasets
  TARGET: 750 → 1300


In [None]:
# Dependency Pin & Compat Patches
import sys
import subprocess
import importlib

def pin(pkg, ver):
    """Make sure ``pkg`` is installed at exactly version ``ver``.

    The on-disk version is read from the package metadata. If it differs, or
    the package is missing, ``pip install pkg==ver`` is run with this
    interpreter.

    ``importlib.import_module`` returns the object already cached in
    ``sys.modules``. So if ``pkg`` was imported before the install, this
    process keeps running the old code. That case is reported with a
    ``[WARN]`` line asking for a runtime restart, rather than a misleading
    ``[OK]``.

    Args:
        pkg: Package name, used both as the pip requirement name and as the
            import name.
        ver: Exact version string required.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` fails.
    """
    from importlib import metadata

    try:
        installed = metadata.version(pkg)
    except metadata.PackageNotFoundError:
        installed = None

    if installed != ver:
        print(f"[PIN] {pkg}=={ver}")
        subprocess.check_call([sys.executable, "-m", "pip", "install", f"{pkg}=={ver}", "-q"])
        importlib.invalidate_caches()

    loaded = sys.modules.get(pkg)
    if loaded is None:
        loaded = importlib.import_module(pkg)
    elif getattr(loaded, "__version__", None) != ver:
        # The stale module stays in sys.modules; only a restart loads the new files.
        print(f"[WARN] {pkg} {getattr(loaded, '__version__', '?')} is already imported; "
              f"restart the runtime to use {pkg}=={ver}")
        return
    print(f"[OK] {pkg} {loaded.__version__}")

pin("transformers", "4.57.1")
pin("accelerate", "1.2.1")
pin("datasets", "2.21.0")

# unwrap_model compatibility patch: drop a `keep_torch_compile` kwarg before it
# reaches Accelerator.unwrap_model.
# NOTE(review): presumably the pinned transformers passes this kwarg and
# accelerate 1.2.1 does not accept it -- confirm.
# The original method is captured in a closure instead of being looked up
# through a module-level global at call time. A later reassignment of a
# generic global name therefore cannot change what the shim calls.
# The `_w2v_patched` marker prevents the patch from stacking on re-runs.
import functools
import accelerate


def _make_unwrap_shim(orig):
    """Wrap ``orig`` so that a ``keep_torch_compile`` keyword is discarded."""
    @functools.wraps(orig)
    def _shim(self, model, *args, **kw):
        kw.pop("keep_torch_compile", None)
        return orig(self, model, *args, **kw)
    return _shim


if not hasattr(accelerate.Accelerator, "_w2v_patched"):
    accelerate.Accelerator.unwrap_model = _make_unwrap_shim(accelerate.Accelerator.unwrap_model)
    accelerate.Accelerator._w2v_patched = True
    print("[PATCH] unwrap_model compatibility shim active")

[OK] transformers 4.57.1
[OK] accelerate 1.2.1
[OK] datasets 2.21.0
[PATCH] unwrap_model compatibility shim active


In [None]:
import gc, torch, os
# NOTE: the CUDA caching allocator reads PYTORCH_CUDA_ALLOC_CONF when it
# initialises. Changing it after CUDA has been used in this kernel only takes
# effect after a runtime restart.
if torch.cuda.is_available() and torch.cuda.is_initialized():
    print("[WARN] CUDA already initialised; new PYTORCH_CUDA_ALLOC_CONF applies after a restart")
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb=64"

# Drop objects left over from a previous run of these cells so their memory
# can be reclaimed. pop() with a default is a no-op for names that don't exist.
for k in ["trainer", "model", "tok", "optimizer", "scheduler", "dc", "train_ds", "valid_ds"]:
    globals().pop(k, None)

gc.collect()
torch.cuda.empty_cache()
try:
    torch.cuda.ipc_collect()
except Exception as e:  # best effort: e.g. fails when no CUDA device is present
    print(f"[WARN] ipc_collect skipped: {e}")

In [None]:
# Model and Tok
import torch
import gc
from transformers import AutoModelForCausalLM, AutoTokenizer

# Free cached GPU blocks and unreachable Python objects before loading.
torch.cuda.empty_cache()
gc.collect()

# Tokenizer saved alongside the resumed checkpoint.
tok = AutoTokenizer.from_pretrained(str(RESUME_FROM), use_fast=True)
if tok.pad_token_id is None:
    # No pad token defined: reuse EOS (or the literal "</s>" if EOS is unset).
    tok.pad_token = tok.eos_token or "</s>"

model = AutoModelForCausalLM.from_pretrained(
    str(RESUME_FROM),
    torch_dtype=torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True,
    # Choose the attention backend at load time, which is the documented place
    # for it. Assigning `model.config.attn_implementation` after loading is not
    # guaranteed to affect the already-built attention layers.
    attn_implementation="eager",
)

# No KV cache during training.
model.config.use_cache = False

# Freeze every parameter, then re-enable gradients for the input embedding only.
for p in model.parameters():
    p.requires_grad = False

emb = model.get_input_embeddings()
emb.weight.requires_grad = True

# Tie the LM head to the input embedding (one shared Parameter). The checkpoint
# has no lm_head weights (the loader reports lm_head.weight as newly
# initialised), so without this the head would stay random.
model.get_output_embeddings().weight = emb.weight

# Gradient mask on the base vocabulary: only rows at index >= BASE_VOCAB
# (the added tokens) receive non-zero gradients. `tok.vocab_size` is the base
# vocabulary size without added tokens.
# NOTE(review): this assumes the Wake tokens were registered through
# `tok.add_tokens(...)`. If they were merged into the base vocab, no row is
# masked -- confirm.
BASE_VOCAB = tok.vocab_size
n_rows = emb.weight.shape[0]
if BASE_VOCAB < n_rows:
    new_row_mask = torch.zeros(n_rows, 1, dtype=emb.weight.dtype, device=emb.weight.device)
    new_row_mask[BASE_VOCAB:] = 1.0
    # Default arg binds this mask now, not whatever the global names later.
    emb.weight.register_hook(lambda grad, m=new_row_mask: grad * m)
    mask_msg = f"rows < {BASE_VOCAB} frozen, {n_rows - BASE_VOCAB} trainable"
else:
    mask_msg = "NOT applied (no rows beyond the base vocabulary)"

model.train()

# Verification
print("[MODEL LOADED]")
print(f"  Tied weights: {'OK' if emb.weight.data_ptr() == model.get_output_embeddings().weight.data_ptr() else 'NO'}")
print(f"  Embedding trainable: {emb.weight.requires_grad}")
print(f"  Gradient mask: {mask_msg}")
print(f"  Pad token ID: {tok.pad_token_id}")

import collections
devs = collections.Counter(p.device.type for p in model.parameters())
print(f"  Device distribution: {dict(devs)}")

Some weights of LlamaForCausalLM were not initialized from the model checkpoint at /content/drive/MyDrive/wake2vec/runs/t4_1762376560/checkpoint-750-rebuilt and are newly initialized: ['lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[MODEL LOADED]
  Tied weights: OK
  Embedding trainable: True
  Pad token ID: 2
  Device distribution: {'cuda': 200}


In [None]:
# Callbacks
import json
import time
import shutil
from transformers import TrainerCallback
from transformers.trainer_callback import TrainerState

class EvalEveryNSteps(TrainerCallback):
    """Request a log entry and an evaluation pass every ``n`` optimizer steps."""

    def __init__(self, n=200):
        self.n = n

    def on_step_end(self, args, state, control, **kw):
        step = int(state.global_step or 0)
        # Skip step 0 and every step that is not a multiple of n.
        if step == 0 or step % self.n:
            return
        control.should_log = True
        control.should_evaluate = True

class SentryMirror(TrainerCallback):
    """Mirror the newest checkpoint (and metrics JSON) from ``output_dir`` to ``SENTRY``.

    Runs after every save. Only checkpoints that contain model weights are
    mirrored, and each one only once. Failures are printed, never raised, so a
    Drive hiccup cannot abort training.
    """

    # Files whose presence marks a checkpoint as having model weights
    # (single-file or sharded, safetensors or torch format).
    _WEIGHT_FILES = (
        "model.safetensors",
        "pytorch_model.bin",
        "model.safetensors.index.json",
        "pytorch_model.bin.index.json",
    )

    @staticmethod
    def _step_of(path):
        """Return the trailing number of ``checkpoint-<step>``, or -1 if it is not numeric."""
        try:
            return int(path.name.split("-")[-1])
        except ValueError:
            return -1

    def on_save(self, args, state, control, **kw):
        try:
            out = Path(args.output_dir)
            cks = [p for p in out.glob("checkpoint-*") if p.is_dir()]
            if not cks:
                return

            ck = max(cks, key=self._step_of)
            if not any((ck / name).exists() for name in self._WEIGHT_FILES):
                return

            dst = SENTRY / ck.name
            if not dst.exists():
                # Copy into a temporary sibling, then rename. An interrupted copy
                # leaves only "<name>.partial" behind, never a half-written dst
                # that later saves would skip as "already mirrored".
                tmp = SENTRY / f"{ck.name}.partial"
                if tmp.exists():
                    shutil.rmtree(tmp)
                shutil.copytree(ck, tmp)
                tmp.rename(dst)
                print(f"[SENTRY] mirrored {ck.name}")

            # Mirror metrics JSON files (overwritten on every save).
            msrc = out / "metrics"
            if msrc.exists():
                (SENTRY / "metrics").mkdir(parents=True, exist_ok=True)
                for f in msrc.glob("*.json"):
                    shutil.copy2(f, SENTRY / "metrics" / f.name)

        except Exception as e:
            print(f"[SENTRY] mirror failed: {e}")

class EmbeddingSnap(TrainerCallback):
    """Every ``every`` steps, save the input-embedding matrix plus a heartbeat JSON.

    Files are written to ``DRIVE/emb_snaps/RUN_ID``:
      * ``emb_step{step:04d}.pt`` -- the embedding weight as a CPU tensor (``torch.save``)
      * ``heartbeat.json`` -- run id, step, matrix shape and Unix timestamp,
        overwritten on every snapshot.

    Failures are printed, never raised.
    """

    def __init__(self, every=50):
        self.every = every
        (DRIVE / "emb_snaps" / RUN_ID).mkdir(parents=True, exist_ok=True)

    def on_step_end(self, args, state, control, **kw):
        s = int(state.global_step or 0)
        if not s or s % self.every:
            return
        try:
            # Prefer the model the Trainer is actually training (transformers
            # passes it as the `model` kwarg). Fall back to the notebook global
            # `model`, which other cells may delete or rebind to a different model.
            net = kw.get("model")
            if net is None:
                net = model
            E = net.get_input_embeddings().weight.detach().cpu()
            out_path = (DRIVE / "emb_snaps" / RUN_ID) / f"emb_step{s:04d}.pt"
            torch.save(E, out_path)

            # Write heartbeat metadata
            heartbeat = {
                "run_id": RUN_ID,
                "step": s,
                "rows": int(E.size(0)),
                "dim": int(E.size(1)),
                "ts": time.time()
            }
            (out_path.parent / "heartbeat.json").write_text(
                json.dumps(heartbeat, indent=2))

            print(f"[SNAP] embeddings saved to {out_path.name}")

        except Exception as e:
            print(f"[SNAP] failed: {e}")

class StepTimer(TrainerCallback):
    """Print the average wall-clock seconds per optimizer step every ``every`` steps.

    The timing window runs from the last report (or the first step seen, e.g.
    right after a resume) to the current step. The printed average is the
    elapsed time divided by the number of steps actually in that window.
    """

    def __init__(self, every=10):
        self.every = every
        self.prev = None  # global_step at the start of the current window
        self.t = None     # wall-clock time at the start of the current window

    def on_step_end(self, args, state, control, **kw):
        s = int(state.global_step or 0)
        now = time.time()

        if self.prev is None or s < self.prev:
            # First step observed, or the step counter went backwards (new run):
            # open a fresh window here.
            self.prev, self.t = s, now
            return

        if s > self.prev and s % self.every == 0:
            n_steps = s - self.prev
            print(f"[{s:4d}] ~{(now - self.t) / n_steps:.2f}s/step (last {n_steps})")
            # Restart the window only after reporting, so dt spans n_steps steps.
            self.prev, self.t = s, now

print("[CALLBACKS CONFIGURED]")

[CALLBACKS CONFIGURED]


In [None]:
# Pin and patch FIRST
import sys, subprocess, importlib

def pin(pkg, ver):
    """Make sure ``pkg`` is installed at exactly version ``ver``.

    The on-disk version is read from the package metadata. If it differs, or
    the package is missing, ``pip install pkg==ver`` is run with this
    interpreter.

    ``importlib.import_module`` returns the object already cached in
    ``sys.modules``. So if ``pkg`` was imported before the install, this
    process keeps running the old code. That case is reported with a
    ``[WARN]`` line asking for a runtime restart, rather than a misleading
    ``[OK]``.

    Args:
        pkg: Package name, used both as the pip requirement name and as the
            import name.
        ver: Exact version string required.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` fails.
    """
    from importlib import metadata

    try:
        installed = metadata.version(pkg)
    except metadata.PackageNotFoundError:
        installed = None

    if installed != ver:
        print(f"[PIN] {pkg}=={ver}")
        subprocess.check_call([sys.executable, "-m", "pip", "install", f"{pkg}=={ver}", "-q"])
        importlib.invalidate_caches()

    loaded = sys.modules.get(pkg)
    if loaded is None:
        loaded = importlib.import_module(pkg)
    elif getattr(loaded, "__version__", None) != ver:
        # The stale module stays in sys.modules; only a restart loads the new files.
        print(f"[WARN] {pkg} {getattr(loaded, '__version__', '?')} is already imported; "
              f"restart the runtime to use {pkg}=={ver}")
        return
    print(f"[OK] {pkg} {loaded.__version__}")

pin("transformers", "4.57.1")
pin("accelerate", "1.2.1")
pin("datasets", "2.21.0")

# Drop a `keep_torch_compile` kwarg before it reaches Accelerator.unwrap_model.
# NOTE(review): presumably the pinned transformers passes this kwarg and
# accelerate 1.2.1 does not accept it -- confirm.
# The original method is captured in a closure rather than read from a
# module-level global at call time. The `_w2v_patched` marker prevents the
# patch from stacking on re-runs.
import functools
import accelerate


def _make_unwrap_shim(orig):
    """Wrap ``orig`` so that a ``keep_torch_compile`` keyword is discarded."""
    @functools.wraps(orig)
    def _shim(self, model, *args, **kw):
        kw.pop("keep_torch_compile", None)
        return orig(self, model, *args, **kw)
    return _shim


if not hasattr(accelerate.Accelerator, "_w2v_patched"):
    accelerate.Accelerator.unwrap_model = _make_unwrap_shim(accelerate.Accelerator.unwrap_model)
    accelerate.Accelerator._w2v_patched = True
    print("[PATCH] active")

[OK] transformers 4.57.1
[OK] accelerate 1.2.1
[OK] datasets 2.21.0
[PATCH] active


In [None]:
# DATASETS
from pathlib import Path
import shutil

# Mount Drive when running on Colab. Elsewhere (or if the mount fails),
# continue with whatever is already on disk, but report why it was skipped.
try:
    from google.colab import drive
    drive.mount("/content/drive", force_remount=False)
except Exception as e:
    print(f"[DATA] Drive mount skipped: {e}")

DRIVE = Path("/content/drive/MyDrive/wake2vec")
DATASETS = Path("/content/datasets")
DATASETS.mkdir(parents=True, exist_ok=True)

def ensure_dir(src: Path, dst: Path) -> None:
    """Copy directory ``src`` to ``dst`` unless ``dst`` already exists.

    The copy goes into a ``<dst>.partial`` sibling that is renamed to ``dst``
    only once it is complete. An interrupted copy is therefore never mistaken
    for a finished one on the next call.

    Raises:
        FileNotFoundError: if ``src`` does not exist.
    """
    if not src.exists():
        raise FileNotFoundError(f"Missing dataset at {src}")
    if dst.exists():
        print("[DATA] already local:", dst)
        return
    print(f"[DATA] copying {src} → {dst}")
    tmp = dst.with_name(dst.name + ".partial")
    if tmp.exists():
        shutil.rmtree(tmp)  # leftover from an earlier interrupted copy
    shutil.copytree(src, tmp)
    tmp.rename(dst)

# Drive originals and their local-disk destinations; copy each split locally.
src_train, src_valid = DRIVE / "datasets" / "train_ds", DRIVE / "datasets" / "valid_ds"
dst_train, dst_valid = DATASETS / "train_ds", DATASETS / "valid_ds"
ensure_dir(src_train, dst_train)
ensure_dir(src_valid, dst_valid)

# Read each split from the local copy when present, otherwise straight from Drive.
try:
    train_ds_path = str(src_train if not dst_train.exists() else dst_train)
    valid_ds_path = str(src_valid if not dst_valid.exists() else dst_valid)
    print("[DATA] train_ds:", train_ds_path)
    print("[DATA] valid_ds:", valid_ds_path)
except Exception as err:
    raise RuntimeError(f"Dataset path resolution failed: {err}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[DATA] already local: /content/datasets/train_ds
[DATA] already local: /content/datasets/valid_ds
[DATA] train_ds: /content/datasets/train_ds
[DATA] valid_ds: /content/datasets/valid_ds


In [None]:
# FORCE CLEAN GPU
import gc
import torch

# Free GPU memory held by model/trainer *instances* bound to notebook globals.
# Names are matched case-insensitively on "model" / "trainer", but classes,
# modules and functions are kept. Deleting e.g. `AutoModelForCausalLM` or
# `TrainerCallback` here would break later cells that still use those names,
# and those objects hold no GPU memory anyway.
import types

_KEEP_TYPES = (type, types.ModuleType, types.FunctionType, types.BuiltinFunctionType)
for name in list(globals()):
    lowered = name.lower()
    if "model" not in lowered and "trainer" not in lowered:
        continue
    if isinstance(globals()[name], _KEEP_TYPES):
        continue
    del globals()[name]

gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

print("[MEMORY] GPU cleaned")

[MEMORY] GPU cleaned


In [None]:
import gc
import torch

# Drop model/trainer *instances* bound to notebook globals (matched
# case-insensitively on "model" / "trainer"). Classes, modules and functions
# with matching names, e.g. `AutoModelForCausalLM`, are kept because later
# cells still use them.
import types

_KEEP_TYPES = (type, types.ModuleType, types.FunctionType, types.BuiltinFunctionType)
for name in list(globals()):
    lowered = name.lower()
    if ("model" in lowered or "trainer" in lowered) and not isinstance(globals()[name], _KEEP_TYPES):
        del globals()[name]

gc.collect()
torch.cuda.empty_cache()
print("[MEMORY] Cleaned")

[MEMORY] Cleaned


The following cells support the Phase 1 (P1) evaluation.

In [None]:
# `pathlib` is imported here because no earlier cell imports it under that name.
import pathlib

print("TRAIN output_dir:", trainer.args.output_dir)
print("RUN_ID:", pathlib.Path(trainer.args.output_dir).name)
print("global_step:", trainer.state.global_step)

In [None]:
import inspect
from transformers import TrainingArguments
# True only if this transformers build still accepts `evaluation_strategy`
# (newer releases use `eval_strategy`). The cell displays the result.
_ta_init_params = inspect.signature(TrainingArguments.__init__).parameters
'evaluation_strategy' in _ta_init_params

In [None]:
from transformers import TrainerCallback

class EvalEveryNSteps(TrainerCallback):
    """Every ``n`` steps, request a checkpoint save, a log entry and an evaluation."""

    def __init__(self, n=200):
        self.n = n

    def on_step_end(self, args, state, control, **kw):
        step = state.global_step
        if not step or step % self.n != 0:
            return
        control.should_save = True  # ensure state gets flushed
        control.should_log = True
        control.should_evaluate = True

In [None]:
# P1 finalise overlap@5, norm stats, and a loss plot
import json, numpy as np, pathlib, matplotlib.pyplot as plt

DRIVE_ROOT = pathlib.Path("/content/drive/MyDrive/wake2vec")
# Most recently modified t4_* run on Drive; its local dir under /content/runs is evaluated.
RUNS = sorted((DRIVE_ROOT/"runs").glob("t4_*"), key=lambda p: p.stat().st_mtime)
if not RUNS:
    raise FileNotFoundError(f"No t4_* run directories under {DRIVE_ROOT/'runs'}")
RUN_DIR = pathlib.Path("/content/runs")/RUNS[-1].name
METRICS_DIR = RUN_DIR/"metrics"
PLOTS_DIR = RUN_DIR/"plots"; PLOTS_DIR.mkdir(parents=True, exist_ok=True)

# Load current embeddings from the final checkpoint of this run.
from transformers import AutoModelForCausalLM
BASE_CKPT = RUN_DIR/"checkpoint-final"
# Use a private name and drop it right away. Assigning to the global `model`
# here would replace the prepared (frozen, tied) training model that the
# training cell passes to Trainer and that EmbeddingSnap reads.
_p1_model = AutoModelForCausalLM.from_pretrained(str(BASE_CKPT), torch_dtype="float32", device_map=None)
E_post = _p1_model.get_input_embeddings().weight.detach().cpu().numpy()
del _p1_model

# Optional composed init + ids
E_COMP = DRIVE_ROOT/"E_comp.npy"
NEW_IDS = DRIVE_ROOT/"new_ids.npy"
has_comp = E_COMP.exists() and NEW_IDS.exists()

def topk_overlap(a, b, k=5, chunk=1024):
    """Mean top-k nearest-neighbour overlap between two aligned embedding tables.

    Row ``i`` of ``a`` and row ``i`` of ``b`` are two embeddings of the same
    item. For each item, its ``k`` nearest neighbours by cosine similarity are
    found *within* ``a`` and, separately, *within* ``b``; the item itself is
    excluded. The score is the size of the intersection of the two neighbour
    sets, averaged over items: ``k`` means the neighbourhood is unchanged, 0
    means it is completely different.

    Args:
        a, b: 2-D array-likes with the same number of rows (column counts may differ).
        k: neighbourhood size; clamped to ``n - 1``.
        chunk: rows processed at a time, bounding each similarity block to
            ``chunk x n`` instead of ``n x n``.

    Returns:
        Mean overlap as a float in ``[0, k]``; 0.0 when there are fewer than two rows.

    Raises:
        ValueError: if ``a`` and ``b`` have different row counts.
    """
    import numpy as np
    from numpy.linalg import norm

    a = np.asarray(a, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)
    if a.shape[0] != b.shape[0]:
        raise ValueError(f"row count mismatch: {a.shape[0]} vs {b.shape[0]}")
    n = a.shape[0]
    k = min(int(k), n - 1)
    if k <= 0:
        return 0.0

    a = a / (norm(a, axis=1, keepdims=True) + 1e-9)
    b = b / (norm(b, axis=1, keepdims=True) + 1e-9)

    def _neighbours(x, start, stop):
        """Indices of the k most similar rows of ``x`` for rows start..stop-1 (self excluded)."""
        sims = x[start:stop] @ x.T
        local = np.arange(stop - start)
        sims[local, start + local] = -np.inf
        return np.argpartition(-sims, k - 1, axis=1)[:, :k]

    total = 0
    for start in range(0, n, chunk):
        stop = min(start + chunk, n)
        for na, nb in zip(_neighbours(a, start, stop), _neighbours(b, start, stop)):
            total += len(set(na.tolist()) & set(nb.tolist()))
    return total / n

from numpy.linalg import norm

report = {}

if has_comp:
    E_comp = np.load(E_COMP)
    new_ids = np.load(NEW_IDS)
    # Row i of E_comp must correspond to token new_ids[i].
    if E_comp.shape[0] != len(new_ids):
        raise ValueError(f"E_comp has {E_comp.shape[0]} rows but new_ids lists {len(new_ids)} ids")
    E_post_new = E_post[new_ids]
    overlap5 = topk_overlap(E_comp, E_post_new, k=5)
    # Norm drift: mean change in L2 norm of each new-token row, E_post vs E_comp.
    dn = (norm(E_post_new, axis=1) - norm(E_comp, axis=1)).mean()
    report.update({"overlap_at_5": float(overlap5), "mean_delta_norm": float(dn), "n_new": int(len(new_ids))})
else:
    # Fallback: norm statistics over the whole post-training embedding matrix.
    norms = norm(E_post, axis=1)
    report.update({"post_mean_norm": float(norms.mean()), "post_std_norm": float(norms.std()), "n_vocab": int(E_post.shape[0])})

# JSON (the metrics dir may not exist yet for a run that never wrote metrics).
METRICS_DIR.mkdir(parents=True, exist_ok=True)
(METRICS_DIR/"p1_summary.json").write_text(json.dumps(report, indent=2))
print("[P1 SUMMARY]", json.dumps(report, indent=2))

# Loss plot
import json, glob
state_files = [RUN_DIR/"trainer_state.json", BASE_CKPT/"trainer_state.json"]
state_files = [p for p in state_files if p.exists()]
logs = []
for sf in state_files:
    s = json.loads(sf.read_text())
    logs.extend([d for d in s.get("log_history", []) if "loss" in d])

if logs:
    steps = [d["step"] for d in logs]
    losses = [float(d["loss"]) for d in logs]
    ema = []
    alpha = 0.1
    for i,x in enumerate(losses):
        ema.append(x if i==0 else alpha*x + (1-alpha)*ema[-1])
    plt.figure(figsize=(7,4.5))
    plt.plot(steps, losses, label="loss")
    plt.plot(steps, ema, label="EMA(0.1)")
    plt.title(f"P1 Loss — {RUN_DIR.name}")
    plt.xlabel("step"); plt.ylabel("loss"); plt.grid(True, linewidth=0.3); plt.legend()
    outp = PLOTS_DIR/"p1_loss_curve.png"
    plt.savefig(outp, dpi=140, bbox_inches="tight")
    print("[PLOT]", outp)
else:
    print("[WARN] No trainer_state logs found; skip loss plot.")

In [None]:
# Training Config and Execution
from transformers import TrainingArguments, Trainer

import json
import time

# Fail fast if prerequisite cells were skipped, or their results were deleted
# by the cleanup cells. Otherwise a bare NameError would surface part-way
# through setup.
_required = ("model", "train_ds", "valid_ds", "dc_trunc", "MAX_LEN")
_missing = [n for n in _required if n not in globals()]
if _missing:
    raise NameError(
        "Run the model / dataset / collator cells first; undefined: " + ", ".join(_missing)
    )

# Training args
args = TrainingArguments(
    output_dir=str(LOCAL_RUN),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    max_steps=TARGET,
    learning_rate=5e-4,
    warmup_ratio=0.0,
    optim="adafactor",
    logging_steps=25,
    save_strategy="steps",
    save_steps=75,
    save_total_limit=3,
    gradient_checkpointing=True,
    fp16=False,
    bf16=False,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    dataloader_persistent_workers=False,
    report_to=["none"],
    max_grad_norm=1.0,
)

# Initialize trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    data_collator=dc_trunc,
    callbacks=[
        EvalEveryNSteps(200),
        SentryMirror(),
        EmbeddingSnap(50),
        StepTimer(10)
    ],
)

# The rebuilt checkpoint may lack trainer_state.json. If so, write a minimal one
# recording global_step=LAST_STEP and max_steps=TARGET for the resume to use.
ts = RESUME_FROM / "trainer_state.json"
if not ts.exists():
    state_dict = {
        "global_step": LAST_STEP,
        "max_steps": TARGET,
        "log_history": []
    }
    ts.write_text(json.dumps(state_dict, indent=2))
    print(f"[RESUME] wrote trainer_state.json at step {LAST_STEP}")

# training
print(f"[GO] Resuming from step {LAST_STEP} with MAX_LEN={MAX_LEN}")
print(f"     Target: {TARGET} | Gradient checkpointing: ON")

t0 = time.time()
trainer.train(resume_from_checkpoint=str(RESUME_FROM))
elapsed = time.time() - t0

print(f"[COMPLETE] Training finished in {elapsed:.1f}s ({elapsed/60:.1f}min)")