# Wake2Vec (Resume + Sentry Helpers)

This notebook safely resumes Phase-1 (embeddings-only) training from the latest valid checkpoint and mirrors every new checkpoint + metrics to a secondary backup (`sentry_backups/`). It also includes a P1-finalize cell for geometry metrics (overlap@5, Δ‖E‖) and a loss plot. Use this notebook alongside the main training notebook.

---

## Notebook insights

- **Pin + Shim:** Ensures compatible versions (`transformers==4.57.1`, `accelerate==1.2.1`, `datasets==2.20.0`) and, when the installed Accelerate does not accept it, patches `Accelerator.unwrap_model(...)` to drop the `keep_torch_compile` keyword argument.
- **GPU Probe:** Verifies the T4 is visible and a tiny forward pass works.
- **Safe Resume:** Finds the newest checkpoint with weights (in `runs/<RUN_ID>/` or `sentry_backups/<RUN_ID>/`) and resumes to your `TARGET` step (e.g., 1300) with `save_steps=100`.
- **SentryMirror:** On each save, mirrors the newest checkpoint and metrics JSON to `sentry_backups/<RUN_ID>/...` for extra safety.
- **EvalEveryNSteps:** Triggers eval every 200 steps (works even if `evaluation_strategy` isn’t supported in your env).
- **P1-Finalize:** After training lands, computes overlap@5 and mean Δ‖E‖ for the new tokens, and writes a clean loss plot.

---

## Assumptions / Paths

- **Drive base:** `/content/drive/MyDrive/wake2vec`
- **Runs:** `.../runs/<RUN_ID>/` (created by the main notebook)
- **Backups:** `.../sentry_backups/<RUN_ID>/`
- **Datasets (prebuilt):** `.../datasets/train_ds` and `.../datasets/valid_ds`
- **Model base:** `TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T`
- **New-token geometry inputs (optional for P1-finalize):**  
  `E_comp.npy` (composed init vectors) and `new_ids.npy` (indices of new rows)

> If `E_comp.npy` / `new_ids.npy` are missing, finalize still runs and just skips overlap/Δ‖E‖.



Mid-run health ping

In [None]:
# progress ping
import json, pathlib, time

# Directories that are searched for t4_* run folders: the Colab-local runs
# directory first, then the Drive-backed one. latest_run_dir() skips any base
# that does not exist.
CANDIDATE_BASES = [
    pathlib.Path("/content/runs"),
    pathlib.Path("/content/drive/MyDrive/wake2vec/runs"),
]

def latest_run_dir():
    """Return the most recently modified ``t4_*`` entry under CANDIDATE_BASES.

    Bases that do not exist are skipped. Returns ``None`` when no ``t4_*``
    entry is found in any base.
    """
    found = [
        entry
        for root in CANDIDATE_BASES
        if root.exists()
        for entry in root.glob("t4_*")
    ]
    if len(found) == 0:
        return None

    def modified_at(entry):
        return entry.stat().st_mtime

    return max(found, key=modified_at)

def read_trainer_state(run_dir):
    """Load the first readable ``trainer_state.json`` belonging to a run.

    Candidates are tried in this order:

    1. ``run_dir/trainer_state.json``
    2. ``run_dir/checkpoint-final/trainer_state.json``
    3. every other ``run_dir/checkpoint-*/trainer_state.json``, newest first
       by modification time.

    The checkpoint fallback exists because the state file may only be present
    inside checkpoint directories (``load_state`` later in this file searches
    them for the same reason).

    Args:
        run_dir: ``pathlib.Path`` of the run directory.

    Returns:
        The parsed JSON value, or ``None`` when no candidate could be read and
        parsed.
    """
    candidates = [
        run_dir / "trainer_state.json",
        run_dir / "checkpoint-final" / "trainer_state.json",
    ]
    in_checkpoints = sorted(
        run_dir.glob("checkpoint-*/trainer_state.json"),
        key=lambda p: p.stat().st_mtime,
        reverse=True,
    )
    candidates += [p for p in in_checkpoints if p not in candidates]

    for path in candidates:
        try:
            return json.loads(path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Missing, unreadable, or half-written file (JSONDecodeError is a
            # ValueError): try the next candidate instead of giving up.
            continue
    return None

def read_loss_log(run_dir):
    """Return the parsed phase-1 (preferred) or phase-2 loss log of a run.

    Looks in ``run_dir/metrics`` for ``phase1_loss_log.json`` and then
    ``phase2_loss_log.json``; a file that is absent or fails to load is
    skipped. Returns ``None`` when neither yields a result.
    """
    metrics_dir = run_dir / "metrics"
    for filename in ("phase1_loss_log.json", "phase2_loss_log.json"):
        log_path = metrics_dir / filename
        if not log_path.exists():
            continue
        try:
            parsed = json.loads(log_path.read_text())
        except Exception:
            continue
        return parsed
    return None

def _ema(values, alpha=0.1):
    """Return the exponential moving average of *values*, or None if empty.

    The first value seeds the average. ``None`` (not ``0.0``) marks the
    unseeded state so that a genuine 0.0 average is not mistaken for
    "not started" and silently re-seeded.
    """
    smoothed = None
    for value in values:
        smoothed = value if smoothed is None else alpha * value + (1 - alpha) * smoothed
    return smoothed


# Report the latest step, loss and smoothed loss of the newest run. The
# trainer state is preferred; the streamed loss log is the fallback.
alpha = 0.1
run = latest_run_dir()
if run is None:
    print("No t4_* run directory found in /content/runs or Drive. Is your run writing to Drive?")
else:
    state = read_trainer_state(run)
    if state and "log_history" in state:
        # Keep only training-loss entries; other log entries carry no "loss".
        logs = [d for d in state["log_history"] if "loss" in d]
        if logs:
            ema = _ema((float(d["loss"]) for d in logs), alpha)
            last = logs[-1]
            print(f"{run.name} | step {last['step']} | loss {float(last['loss']):.4f} | EMA {ema:.4f}")
        else:
            print(f"{run.name} | trainer_state.json has no loss logs yet.")
    else:
        # fallback
        recs = read_loss_log(run) or []
        if recs:
            ema = _ema((float(r["loss"]) for r in recs), alpha)
            last = recs[-1]
            print(f"{run.name} | step {last.get('step','?')} | loss {float(last['loss']):.4f} | EMA {ema:.4f} (from loss log)")
        else:
            print(f"{run.name} | no trainer_state or loss log found yet — try again in ~30–60s.")

t4_1762376560 | step 550 | loss 3.2388 | EMA 6.1215 (from loss log)


In [None]:
# quick glance
import json, pathlib
# Newest t4_* run across the local and Drive run roots, judged by mtime.
run_roots = (
    pathlib.Path("/content/runs"),
    pathlib.Path("/content/drive/MyDrive/wake2vec/runs"),
)
run = max(
    [entry for root in run_roots for entry in root.glob("t4_*")],
    key=lambda entry: entry.stat().st_mtime,
)

# Show the last five records of the phase-1 loss log as "step  loss" rows.
logs = json.loads((run / "metrics" / "phase1_loss_log.json").read_text())
tail_rows = [f"{r['step']:>5}  {r['loss']:.4f}" for r in logs[-5:]]
print("\n".join(tail_rows))

  450  4.7429
  500  4.0474
  550  3.2388
  600  2.4137
  650  1.5350


In [None]:
import json, pathlib

def latest_run():
    """Return the most recently modified ``t4_*`` run folder, or ``None``.

    Searches ``/content/runs`` and the Drive runs folder, skipping whichever
    does not exist.
    """
    newest, newest_mtime = None, None
    for base in (
        pathlib.Path("/content/runs"),
        pathlib.Path("/content/drive/MyDrive/wake2vec/runs"),
    ):
        if not base.exists():
            continue
        for candidate in base.glob("t4_*"):
            mtime = candidate.stat().st_mtime
            # Strict ">" keeps the first-seen entry on equal mtimes.
            if newest_mtime is None or mtime > newest_mtime:
                newest, newest_mtime = candidate, mtime
    return newest

def load_state(run):
    """Load a run's trainer state, falling back to its checkpoints.

    ``run/trainer_state.json`` is tried first. If it is absent or cannot be
    read and parsed, each ``run/checkpoint-*/trainer_state.json`` is tried,
    newest first by modification time.

    Args:
        run: ``pathlib.Path`` of the run directory.

    Returns:
        ``(state, path)`` where *state* is the parsed JSON and *path* is the
        file it came from, or ``(None, None)`` when nothing could be loaded.
    """
    root = run / "trainer_state.json"
    if root.exists():
        try:
            return json.loads(root.read_text()), root
        except (OSError, ValueError):
            # Unreadable or half-written file (JSONDecodeError is a
            # ValueError). Only these are swallowed so Ctrl-C still works.
            pass

    newest_first = sorted(
        run.glob("checkpoint-*/trainer_state.json"),
        key=lambda p: p.stat().st_mtime,
        reverse=True,
    )
    for path in newest_first:
        try:
            return json.loads(path.read_text()), path
        except (OSError, ValueError):
            continue
    return None, None

def load_loss_log(run):
    """Return the parsed phase-1 (preferred) or phase-2 loss log, or ``None``.

    Files are looked up under ``run/metrics``. A file that is missing,
    unreadable, or not valid JSON is skipped and the next name is tried.
    """
    for name in ("phase1_loss_log.json", "phase2_loss_log.json"):
        p = run / "metrics" / name
        if not p.exists():
            continue
        try:
            return json.loads(p.read_text())
        except (OSError, ValueError):
            # JSONDecodeError is a ValueError. Anything else (including
            # KeyboardInterrupt) is allowed to propagate.
            continue
    return None

# One-line status for the newest run: the last trainer-state log entry if a
# state file can be loaded, otherwise the last record of the loss log.
run = latest_run()
if not run:
    print("No run folder found in /content/runs or Drive.")
else:
    state, where = load_state(run)
    if not state:
        recs = load_loss_log(run)
        if not recs:
            print(f"{run.name} | no trainer_state or loss log visible yet.")
        else:
            last = recs[-1]
            print(f"{run.name} | metrics log | step {last.get('step','?')} | loss {float(last['loss']):.4f}")
    else:
        # Keep entries that report either a training loss or an eval loss.
        logs = [
            entry
            for entry in state.get("log_history", [])
            if any(key in entry for key in ("loss", "eval_loss"))
        ]
        if not logs:
            print(f"{run.name} | {where.parent.name} | trainer_state has no logs yet.")
        else:
            last = logs[-1]
            if "eval_loss" not in last:
                print(f"{run.name} | {where.parent.name} | step {last['step']} | loss {float(last['loss']):.4f}")
            else:
                print(f"{run.name} | {where.parent.name} | step {last['step']} | eval_loss {last['eval_loss']:.4f}")

t4_1762376560 | checkpoint-700 | step 700 | loss 0.7782


In [None]:
import pathlib, shutil
DRIVE = pathlib.Path("/content/drive/MyDrive/wake2vec")

# Newest run on Drive, then that run's checkpoints ordered newest-first (mtime).
run = max(list(DRIVE.glob("runs/t4_*")), key=lambda d: d.stat().st_mtime)
ckpts = sorted(
    run.glob("checkpoint-*"),
    key=lambda d: d.stat().st_mtime,
    reverse=True,
)
assert ckpts, "No checkpoints yet."

# Copy the newest checkpoint under sentry_backups/<run>/ unless a copy exists.
src = ckpts[0]
backup_dir = DRIVE / "sentry_backups" / run.name
backup_dir.mkdir(parents=True, exist_ok=True)
dst = backup_dir / src.name
if not dst.exists():
    shutil.copytree(src, dst)
print("[SENTRY] Backed up", src, "→", dst)

[SENTRY] Backed up /content/drive/MyDrive/wake2vec/runs/t4_1762376560/checkpoint-700 → /content/drive/MyDrive/wake2vec/sentry_backups/t4_1762376560/checkpoint-700


In [None]:
import shutil, pathlib
DRIVE = pathlib.Path("/content/drive/MyDrive/wake2vec")

# Newest run on Drive (by mtime) and its metrics folder inside sentry_backups.
run = max(DRIVE.glob("runs/t4_*"), key=lambda d: d.stat().st_mtime)
dstm = DRIVE / "sentry_backups" / run.name / "metrics"
dstm.mkdir(parents=True, exist_ok=True)

# Copy every metrics JSON across, overwriting older copies of the same name.
for metrics_file in (run / "metrics").glob("*.json"):
    shutil.copy2(metrics_file, dstm / metrics_file.name)
print("[SENTRY] Metrics mirrored to", dstm)

[SENTRY] Metrics mirrored to /content/drive/MyDrive/wake2vec/sentry_backups/t4_1762376560/metrics


In [None]:
# backup newest checkpoint to Drive/sentry_backups
import pathlib, shutil
DRIVE = pathlib.Path("/content/drive/MyDrive/wake2vec")
run = max(DRIVE.glob("runs/t4_*"), key=lambda d: d.stat().st_mtime)
backup_root = DRIVE / "sentry_backups" / run.name

# Checkpoints newest-first by mtime; only the newest one is backed up, and
# only when no copy of it exists yet.
ckpts = sorted(
    run.glob("checkpoint-*"),
    key=lambda d: d.stat().st_mtime,
    reverse=True,
)
if not ckpts:
    print("No checkpoint yet.")
else:
    src = ckpts[0]
    dst = backup_root / src.name
    if not dst.exists():
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copytree(src, dst)
        print("[SENTRY] Backed up", src.name)

# mirror metrics
dstm = backup_root / "metrics"
dstm.mkdir(parents=True, exist_ok=True)
for metrics_file in (run / "metrics").glob("*.json"):
    shutil.copy2(metrics_file, dstm / metrics_file.name)
print("[SENTRY] Metrics mirrored →", dstm)

[SENTRY] Metrics mirrored → /content/drive/MyDrive/wake2vec/sentry_backups/t4_1762376560/metrics


In [None]:
import json, pathlib
DRIVE = pathlib.Path("/content/drive/MyDrive/wake2vec")

# Newest run on Drive (by mtime) and the final record of its phase-1 loss log.
run = max(DRIVE.glob("runs/t4_*"), key=lambda d: d.stat().st_mtime)
loss_log_path = run / "metrics" / "phase1_loss_log.json"
logs = json.loads(loss_log_path.read_text())
final_record = logs[-1]
print("step", final_record["step"], "loss", round(float(final_record["loss"]), 4))

step 650 loss 1.535


tbc


In [None]:
# Pin + Shim
import sys, subprocess, importlib, inspect, os

import importlib.metadata

TARGET = {"transformers":"4.57.1", "accelerate":"1.2.1", "datasets":"2.20.0"}


def _installed_version(pkg):
    """Return the version of *pkg* installed on disk, or None if absent."""
    try:
        return importlib.metadata.version(pkg)
    except importlib.metadata.PackageNotFoundError:
        return None


# Install any package whose on-disk version differs from its pin, then report
# what is really installed. The version is read from package metadata rather
# than from an imported module: a module imported before the install keeps
# reporting its old __version__ for the rest of the session, which would make
# an "[OK]" line wrong.
for pkg, ver in TARGET.items():
    if _installed_version(pkg) != ver:
        subprocess.check_call([sys.executable, "-m", "pip", "install", f"{pkg}=={ver}", "-q"])
        # Let the metadata finders notice the freshly written distribution.
        importlib.invalidate_caches()
    on_disk = _installed_version(pkg)
    loaded = getattr(sys.modules.get(pkg), "__version__", None)
    if on_disk != ver:
        print(f"[WARN] {pkg} {on_disk} installed, expected {ver}")
    elif loaded is not None and loaded != ver:
        # The pinned version is on disk, but this process already imported a
        # different one; only a runtime restart swaps it out.
        print(f"[RESTART NEEDED] {pkg} {on_disk} installed, but {loaded} is already loaded")
    else:
        print(f"[OK] {pkg} {on_disk}")

# Accelerate unwrap_model compat
from accelerate import Accelerator
def _ignore_keep_torch_compile(accelerator_cls):
    """Make ``accelerator_cls.unwrap_model`` tolerate ``keep_torch_compile``.

    When the method's signature has no ``keep_torch_compile`` parameter it is
    replaced by a wrapper that drops that keyword before delegating.

    The original method is captured in this function's closure, not in a
    module-level name. A module-level name is looked up at call time, so
    re-running the cell would rebind it to the previous wrapper and make the
    wrapper call itself forever. A marker attribute on the wrapper also makes
    repeated calls a no-op.

    Returns:
        True if the patch was applied by this call, False otherwise.
    """
    current = accelerator_cls.unwrap_model
    if getattr(current, "_wake2vec_compat", False):
        return False  # already patched by an earlier run of this cell
    if "keep_torch_compile" in inspect.signature(current).parameters:
        return False  # this Accelerate version accepts the kwarg natively

    def _compat(self, model, *a, **kw):
        kw.pop("keep_torch_compile", None)
        return current(self, model, *a, **kw)

    _compat._wake2vec_compat = True
    accelerator_cls.unwrap_model = _compat
    return True


if _ignore_keep_torch_compile(Accelerator):
    print("[PATCH] unwrap_model compat applied")

# conservative CUDA settings
# setdefault() leaves any value already present in the environment untouched.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging aid and is
# expected to slow training — confirm it is wanted for the whole run.
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
# NOTE(review): presumably picked up by PyTorch's CUDA allocator; it likely has
# to be set before CUDA is first initialised in this process — confirm.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# IPython shell escape (notebook-only, not valid plain Python): list visible
# GPUs, or print a fallback message when the command fails.
!nvidia-smi -L || echo "No GPU visible"

[OK] transformers 4.57.1
[OK] accelerate 1.11.0
[OK] datasets 4.0.0
GPU 0: Tesla T4 (UUID: GPU-871dba84-49a5-27fe-ae76-5453b21a035c)


In [None]:
# correct versions
import sys, subprocess, importlib, os

import importlib.metadata

PINS = {
    "transformers": "4.57.1",
    "accelerate":   "1.2.1",
    "datasets":     "2.20.0",
}

# to_install: pins whose on-disk version differs (or is missing).
# stale: packages already imported in this process at a non-pinned version;
#        these need a restart even if the right version is now on disk.
to_install = []
stale = []
for pkg, ver in PINS.items():
    try:
        on_disk = importlib.metadata.version(pkg)
    except importlib.metadata.PackageNotFoundError:
        on_disk = None
    if on_disk != ver:
        to_install.append(f"{pkg}=={ver}")
    loaded = getattr(sys.modules.get(pkg), "__version__", None)
    if loaded is not None and loaded != ver:
        stale.append(pkg)

if to_install:
    print("[INSTALL]", to_install)
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", *to_install])

if to_install or stale:
    print("[RESTART] Restarting runtime for pinned versions…")
    sys.stdout.flush()  # os._exit skips normal interpreter shutdown and its flushes
    # End only the kernel process so the machine, and the packages just
    # installed on it, are kept. google.colab.runtime.unassign() is avoided on
    # purpose: it releases the whole runtime, so a reconnect would land on a
    # fresh machine without the pinned installs.
    os._exit(0)
else:
    print("[OK] versions already pinned.")

[INSTALL] ['accelerate==1.2.1', 'datasets==2.20.0']
[RESTART] Restarting runtime for pinned versions…


In [None]:
# tiny GPU probe
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hub id of the base model loaded for the probe.
BASE = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"

tok = AutoTokenizer.from_pretrained(BASE, use_fast=True)
# Fall back to the EOS token (or the literal "</s>") when no pad token is set.
if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token or "</s>"

# NOTE(review): the recorded output of this cell warns that `torch_dtype` is
# deprecated in favour of `dtype` — consider switching once versions are pinned.
model = AutoModelForCausalLM.from_pretrained(
    BASE, torch_dtype=torch.float32, device_map="auto"
)
model.config.use_cache = False

# tie lm_head to input embeddings (safety re-tie)
with torch.no_grad():
    model.get_output_embeddings().weight = model.get_input_embeddings().weight

# one tiny forward on GPU
x = tok("riverrun past Eve and Adam's", return_tensors="pt")
# Move each input tensor to the device that holds the model's first parameter.
x = {k: v.to(next(model.parameters()).device) for k, v in x.items()}
# NOTE(review): this forward sits outside the no_grad block above, so autograd
# is presumably active; wrapping it in torch.no_grad() would suit a pure probe.
_ = model(**x)

print("[GPU OK]", next(model.parameters()).device, "| logits:", _.logits.shape)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/776 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/560 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/4.40G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/129 [00:00<?, ?B/s]

[GPU OK] cuda:0 | logits: torch.Size([1, 10, 32000])


In [None]:
# RESUME P1 → 1300  eval via callback
import pathlib, json, shutil, torch
from datasets import load_from_disk
from transformers import (AutoTokenizer, AutoModelForCausalLM,
                          DataCollatorForLanguageModeling, TrainingArguments, Trainer, TrainerCallback)

# Project root on Drive; runs, sentry backups and datasets are addressed
# relative to it throughout this cell.
DRIVE = pathlib.Path("/content/drive/MyDrive/wake2vec")
# Most recently modified t4_* run directory. max() raises ValueError when the
# glob matches nothing, so at least one run must already exist.
RUN   = max((DRIVE/"runs").glob("t4_*"), key=lambda p: p.stat().st_mtime)

def list_ckpts(base):
    """Return the step-numbered checkpoint directories under *base*.

    Only directories named ``checkpoint-<digits>`` are returned, sorted by
    ascending step. Other ``checkpoint-*`` entries are skipped; in particular
    ``checkpoint-final``, which this cell writes when training finishes, has
    no step number and would make ``int()`` raise ValueError on the next
    resume.

    Args:
        base: ``pathlib.Path`` to scan; a missing path yields ``[]``.
    """
    if not base.exists():
        return []
    numbered = []
    for p in base.glob("checkpoint-*"):
        suffix = p.name.split("-")[-1]
        if p.is_dir() and suffix.isdecimal():
            numbered.append((int(suffix), p))
    numbered.sort(key=lambda pair: pair[0])
    return [p for _, p in numbered]

def has_weights(ck):
    """Return True when *ck* holds ``model.safetensors`` or ``pytorch_model.bin``."""
    weight_files = ("model.safetensors", "pytorch_model.bin")
    return any((ck / filename).exists() for filename in weight_files)

# pick newest checkpoint
# "Newest" here means highest step number, not most recent mtime. Candidates
# come from the run directory and from its sentry backup.
ck_main   = list_ckpts(RUN)
ck_sentry = list_ckpts(DRIVE/"sentry_backups"/RUN.name)
# Highest step first. sorted() is stable, so for equal step numbers the copy
# under RUN stays ahead of the sentry copy.
# NOTE(review): the int(...) key assumes every listed name ends in a step number.
cands = sorted(ck_main + ck_sentry, key=lambda p: int(p.name.split("-")[-1]), reverse=True)
# First candidate that has a weights file; None when no candidate qualifies.
GOOD = next((p for p in cands if has_weights(p)), None)
assert GOOD is not None, "No valid checkpoint (with weights) found."

# Step encoded in the chosen checkpoint's name, and the step to train up to.
last_step = int(GOOD.name.split("-")[-1])
TARGET = 1300
print(f"[RESUME] {RUN.name} from {GOOD.name} (last_step={last_step}) → target={TARGET}")

# load model/tokenizer & tie head
# Both tokenizer and model are read from the chosen checkpoint directory.
tok = AutoTokenizer.from_pretrained(str(GOOD), use_fast=True)
# Fall back to the EOS token (or the literal "</s>") when no pad token is set.
if tok.pad_token_id is None: tok.pad_token = tok.eos_token or "</s>"
model = AutoModelForCausalLM.from_pretrained(str(GOOD), torch_dtype=torch.float32, device_map="auto")
model.config.use_cache = False
# Point the output head at the input embedding weight. The recorded output of
# this cell reports 'lm_head.weight' as missing from the checkpoint, which is
# presumably why the head is re-tied here.
with torch.no_grad():
    model.get_output_embeddings().weight = model.get_input_embeddings().weight
# NOTE(review): nothing in this cell sets requires_grad on any parameter. If
# Phase-1 is meant to train embeddings only (as the notebook header says),
# confirm the freeze is applied somewhere before trainer.train().

# data
train_ds = load_from_disk(str(DRIVE/"datasets"/"train_ds"))
valid_ds = load_from_disk(str(DRIVE/"datasets"/"valid_ds"))
# Evaluate on at most the first 1000 validation rows.
valid_small = valid_ds.select(range(min(1000, len(valid_ds))))
dc = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)

# callbacks
class EvalEveryNSteps(TrainerCallback):
    """Ask the trainer to evaluate whenever the step count is a positive multiple of ``n``."""

    def __init__(self, n=200):
        self.n = n

    def on_step_end(self, args, state, control, **kw):
        step = state.global_step
        if not step:
            # Step 0 (or an unset step) never triggers an evaluation.
            return
        if step % self.n == 0:
            control.should_evaluate = True

class LossStreamer(TrainerCallback):
    """Collect training-loss log events and persist them as a JSON list.

    Args:
        path: destination JSON file; its parent directory is created on the
            first write if it does not exist.
        log_every: the file is rewritten, and a progress line printed, only
            on steps that are positive multiples of this value.

    Records already stored at *path* are loaded at construction, so resuming
    a run extends the existing log instead of overwriting it with only the
    steps logged since the resume.
    """

    def __init__(self, path, log_every=50):
        self.p, self.le = path, log_every
        self.recs = self._load_existing(path)

    @staticmethod
    def _load_existing(path):
        """Return usable records from *path*, or [] if none can be read."""
        import json
        try:
            with open(path, encoding="utf-8") as fh:
                loaded = json.load(fh)
        except (OSError, ValueError):
            return []
        if not isinstance(loaded, list):
            return []
        # Keep only well-formed records so the step comparison in on_log is safe.
        return [r for r in loaded if isinstance(r, dict) and isinstance(r.get("step"), int)]

    def on_log(self, args, state, control, logs=None, **kw):
        """Record a loss event; flush to disk on multiples of ``log_every``."""
        if not logs or "loss" not in logs:
            return
        import json
        import os
        s = int(state.global_step or 0)
        L = float(logs["loss"])
        # Records at or beyond this step were written by an earlier attempt
        # that the resumed run is now redoing; drop them to avoid duplicates.
        while self.recs and self.recs[-1]["step"] >= s:
            self.recs.pop()
        self.recs.append({"step": s, "loss": L, "lr": logs.get("learning_rate")})
        if s > 0 and s % self.le == 0:
            os.makedirs(os.path.dirname(os.fspath(self.p)) or ".", exist_ok=True)
            # Context manager closes the handle, so the data is flushed before
            # another callback copies the file.
            with open(self.p, "w", encoding="utf-8") as fh:
                json.dump(self.recs, fh, indent=2)
            print(f"[P1 {s}] loss={L:.4f}")

class SentryMirror(TrainerCallback):
    """After each save, copy the newest checkpoint and the metrics JSON files
    of RUN into ``sentry_backups/<run name>/`` under DRIVE.

    Any failure is reported with a printed message and otherwise ignored, so
    a failed mirror does not interrupt training.
    """

    def on_save(self, args, state, control, **kw):
        try:
            backup_root = DRIVE / "sentry_backups" / RUN.name

            # Newest checkpoint by mtime; copied only if no copy exists yet.
            newest = max(RUN.glob("checkpoint-*"), key=lambda d: d.stat().st_mtime)
            target = backup_root / newest.name
            if not target.exists():
                target.parent.mkdir(parents=True, exist_ok=True)
                shutil.copytree(newest, target)
                print(f"[SENTRY] mirrored {newest.name}")

            # Metrics are re-copied on every save, overwriting older copies.
            metrics_target = backup_root / "metrics"
            metrics_target.mkdir(parents=True, exist_ok=True)
            for metrics_file in (RUN / "metrics").glob("*.json"):
                shutil.copy2(metrics_file, metrics_target / metrics_file.name)
        except Exception as err:
            print("[SENTRY] mirror failed:", err)

# args
# Configuration for the resumed run: output goes to RUN, a checkpoint is saved
# every 100 steps (save_total_limit=12), loss is logged every 50 steps, and
# training stops at max_steps=TARGET. No evaluation schedule is set here; the
# EvalEveryNSteps callback below requests evaluation instead.
args = TrainingArguments(
    output_dir=str(RUN),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    max_steps=TARGET,
    learning_rate=5e-4,
    warmup_ratio=0.0,
    optim="adafactor",
    logging_steps=50,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=12,
    gradient_checkpointing=True,
    fp16=False, bf16=False,
    report_to=["none"],
    max_grad_norm=1.0,
)

# Evaluation runs on the truncated validation split. The callbacks request an
# eval every 200 steps, stream the loss to metrics/phase1_loss_log.json and
# mirror each saved checkpoint to sentry_backups.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=valid_small,
    data_collator=dc,
    callbacks=[EvalEveryNSteps(200), LossStreamer(str(RUN/"metrics"/"phase1_loss_log.json")), SentryMirror()],
)

# Continue from the selected checkpoint, which may live under sentry_backups.
trainer.train(resume_from_checkpoint=str(GOOD))
# Final model and tokenizer go to RUN/checkpoint-final.
trainer.save_model(str(RUN/"checkpoint-final")); tok.save_pretrained(str(RUN/"checkpoint-final"))
# NOTE(review): printed unconditionally once train() returns; it does not
# itself verify that the step count reached TARGET.
print("[DONE] Reached", TARGET)

[RESUME] t4_1762376560 from checkpoint-300 (last_step=300) → target=1300


Some weights of LlamaForCausalLM were not initialized from the model checkpoint at /content/drive/MyDrive/wake2vec/runs/t4_1762376560/checkpoint-300 and are newly initialized: ['lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The model is already on multiple devices. Skipping the move to device specified in `args`.
There were missing keys in the checkpoint model loaded: ['lm_head.weight'].


Step,Training Loss,Validation Loss
350,5.4883,
400,5.3082,6.273335
450,4.7776,No Log
500,4.0891,No Log
550,3.2604,No Log
600,2.4259,7.170007


[P1 350] loss=5.4883
[P1 400] loss=5.3082
[SENTRY] mirrored checkpoint-400
[P1 450] loss=4.7776
[P1 500] loss=4.0891
[SENTRY] mirrored checkpoint-500
[P1 550] loss=3.2604
[P1 600] loss=2.4259
