<a href="https://colab.research.google.com/github/mahb97/Wakeifier/blob/main/Wake2Vec_750_1300.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wake2Vec: Lexicon-Augmented Embedding Training (Steps 750-1300)

Continuation training for embedding-only fine-tuning of TinyLlama-1.1B with Finnegans Wake lexicon injection.

---

## Overview

This notebook implements the continuation phase (steps 750 to 1300) of embedding-only fine-tuning for TinyLlama-1.1B, whose vocabulary has been extended with approximately 44,000 Wake-specific tokens. Only the embedding matrix is trained: gradient masking confines updates to the new token embeddings, and all other pre-trained parameters stay frozen. The aim is to integrate the new lexicon into the existing semantic space while limiting the risk of catastrophic forgetting.
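As a sketch of what gradient masking means in this setting (the symbols are illustrative assumptions, not names from the notebook): let $E \in \mathbb{R}^{(V_0 + V_w) \times d}$ be the embedding matrix, and assume the first $V_0$ rows belong to the base vocabulary and the remaining $V_w$ rows to the Wake tokens. With loss $\mathcal{L}$ and learning rate $\eta$, a masked update, written here for plain gradient descent, would take the form

$$
m_i = \begin{cases} 0 & i < V_0 \\ 1 & i \ge V_0 \end{cases}
\qquad
\tilde{g}_i = m_i \, \frac{\partial \mathcal{L}}{\partial E_i}
\qquad
E_i \leftarrow E_i - \eta \, \tilde{g}_i
$$

Adafactor rescales the step, but the mask would act on the gradient in the same way before the optimizer sees it, so base-vocabulary rows receive a zero gradient.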

## Training Configuration

Phase: Continuation from checkpoint 750 to step 1300

Model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T

Training strategy: Embedding-only optimization with gradient masking on base vocabulary

Data: Finnegans Wake corpus with held-out validation set

Optimization: Adafactor (learning rate 5e-4, no warmup)

Hardware: Single T4 GPU (15GB VRAM) via Google Colab

## Notebook Structure

This notebook is organized into six sequential cells for reproducible execution:

1. Path configuration and environment setup
2. Dependency pinning and compatibility patches
3. Model and tokenizer loading with gradient masking
4. Dataset loading and sequence truncation
5. Training callback definitions (monitoring, backups, snapshots)
6. Training execution and resumption logic

## Monitoring and Backup Utilities

The notebook provides automated systems for long-running training stability:

- Loss and evaluation metric tracking via structured JSON logs
- Checkpoint inventory and validation (weights, optimizer state, trainer state)
- Sentry mirror system for automated checkpoint backups to Google Drive
- Embedding snapshot capture at 50-step intervals
- Training speed diagnostics (seconds-per-step reporting)
- Heartbeat metadata for remote monitoring
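For remote monitoring, a heartbeat file is only useful if something checks how old it is. The following is a minimal sketch, assuming the heartbeat is a JSON file with `step` and `ts` fields (as written by the embedding snapshot callback); the function name `heartbeat_status` and the 30-minute threshold are hypothetical choices for illustration.

```python
import json
import time
from pathlib import Path

def heartbeat_status(path, max_age_s=1800.0, now=None):
    """Return (step, age_seconds, is_stale) for a heartbeat JSON file.

    max_age_s is an assumed threshold; tune it to the expected gap
    between snapshots on your hardware.
    """
    hb = json.loads(Path(path).read_text())
    now = time.time() if now is None else now
    age = now - float(hb["ts"])          # seconds since the last snapshot
    return int(hb["step"]), age, age > max_age_s
```

Called as `heartbeat_status(DRIVE/"emb_snaps"/RUN_ID/"heartbeat.json")`, it reports the last snapshot step and whether the run appears to have stalled.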

## Implementation Details

The training resumes from checkpoint-750 and proceeds with the following hyperparameters:

- Batch size: 1 (effective batch size 16 via gradient accumulation)
- Gradient accumulation steps: 16
- Learning rate: 5e-4 (no warmup)
- Sequence length: 320 tokens (runtime truncation in the data collator; the direct-to-Drive variant uses 256)
- Save frequency: 75 steps
- Evaluation frequency: 200 steps
- Checkpoint retention: 2 most recent in the local run (`save_total_limit=2`, to limit disk usage); 5 in the direct-to-Drive variant
- Gradient clipping: maximum norm 1.0
- Gradient checkpointing: enabled for memory efficiency
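As a quick sanity check on these figures, the arithmetic can be sketched in plain Python. The constants are copied from the list above; `steps_hit` is a hypothetical helper, and the schedule assumes saves and evaluations fire on exact multiples of their interval, which is worth confirming against the actual logs.

```python
BATCH_SIZE = 1          # per-device batch size
GRAD_ACCUM = 16         # gradient accumulation steps
START, END = 750, 1300  # resume step and target step
SAVE_EVERY, EVAL_EVERY = 75, 200

def steps_hit(every, start, end):
    """Global steps in (start, end] that are multiples of `every`."""
    return [s for s in range(start + 1, end + 1) if s % every == 0]

effective_batch = BATCH_SIZE * GRAD_ACCUM       # sequences per optimizer step
remaining_steps = END - START                   # optimizer steps still to run
micro_batches = remaining_steps * GRAD_ACCUM    # forward/backward passes
save_steps = steps_hit(SAVE_EVERY, START, END)
eval_steps = steps_hit(EVAL_EVERY, START, END)

print("effective batch:", effective_batch)
print("remaining steps:", remaining_steps, "| micro-batches:", micro_batches)
print("saves at:", save_steps)
print("evals at:", eval_steps)
```

Under these assumptions the continuation covers 550 optimizer steps, or 8,800 micro-batches at batch size 1.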

All training callbacks (evaluation triggers, sentry mirroring, embedding snapshots, throughput monitoring) are integrated into the Hugging Face Trainer workflow for automated execution during training. The notebook includes compatibility patches for transformers 4.57.1 and accelerate 1.2.1 to ensure stable execution on Colab environments.
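The compatibility patch follows a common pattern: wrap a method so that a keyword the installed version does not understand is dropped before the call is forwarded. The sketch below shows the pattern on a stand-in class so that it runs without accelerate installed; `LegacyAccelerator` and `install_shim` are hypothetical names, while the keyword `keep_torch_compile` and the guard flag `_w2v_patched` are the ones used in the notebook's patch cell.

```python
class LegacyAccelerator:
    """Hypothetical stand-in for an API that predates a newer keyword."""
    def unwrap_model(self, model, keep_fp32_wrapper=True):
        return model

def install_shim(cls, dropped_kwarg="keep_torch_compile", flag="_w2v_patched"):
    """Patch cls.unwrap_model once so it ignores `dropped_kwarg`.

    Returns True if the patch was applied, False if it was already active.
    """
    if getattr(cls, flag, False):
        return False                      # idempotent: never wrap twice
    original = cls.unwrap_model

    def shim(self, model, *args, **kwargs):
        kwargs.pop(dropped_kwarg, None)   # silently drop the unknown keyword
        return original(self, model, *args, **kwargs)

    cls.unwrap_model = shim
    setattr(cls, flag, True)
    return True

install_shim(LegacyAccelerator)
# The newer-style call now succeeds instead of raising TypeError
result = LegacyAccelerator().unwrap_model("model", keep_torch_compile=False)
```

The guard flag matters in a notebook, where the patch cell may be re-run: without it, each run would wrap the previous wrapper.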

In [None]:
import json, pathlib
DRIVE = pathlib.Path("/content/drive/MyDrive/wake2vec")
run = max((DRIVE/"runs").glob("t4_*"), key=lambda p: p.stat().st_mtime)
log = json.loads((run/"metrics"/"phase1_loss_log.json").read_text())
print(run.name, "step", log[-1]["step"], "loss", round(float(log[-1]["loss"]),4))

t4_1762376560 step 550 loss 3.2604


In [None]:
import json, pathlib
DRIVE = pathlib.Path("/content/drive/MyDrive/wake2vec")
run = max((DRIVE/"runs").glob("t4_*"), key=lambda p: p.stat().st_mtime)
ck = sorted(run.glob("checkpoint-*"), key=lambda p: int(p.name.split("-")[-1]), reverse=True)[0]
state = json.loads((ck/"trainer_state.json").read_text())
ev = [d for d in state.get("log_history", []) if "eval_loss" in d]
print(ev[-1] if ev else "no eval yet")

{'epoch': 10.527472527472527, 'eval_loss': 7.096441268920898, 'eval_runtime': 13.6439, 'eval_samples_per_second': 3.518, 'eval_steps_per_second': 0.44, 'step': 600}


snapshot

In [None]:
# Manual snapshot
import pathlib, shutil, time
DRIVE = pathlib.Path("/content/drive/MyDrive/wake2vec")
RUN   = max((DRIVE/"runs").glob("t4_*"), key=lambda p: p.stat().st_mtime)

def latest_ckpt_with_weights(base):
    cands = sorted(base.glob("checkpoint-*"), key=lambda p: int(p.name.split("-")[-1]), reverse=True)
    for ck in cands:
        if (ck/"model.safetensors").exists() or (ck/"pytorch_model.bin").exists():
            return ck
    return None

src = latest_ckpt_with_weights(RUN)
assert src is not None, "No valid checkpoint with weights found yet."
SNAPS = RUN/"snapshots"; SNAPS.mkdir(exist_ok=True, parents=True)
dst = SNAPS/f"snap_{int(time.time())}_{src.name}"
if not dst.exists():
    shutil.copytree(src, dst)
print("[SNAP] Saved", dst)

[SNAP] Saved /content/drive/MyDrive/wake2vec/runs/t4_1762376560/snapshots/snap_1762565503_checkpoint-300


sentry mirror

In [None]:
# newest full checkpoint + metrics
import pathlib, shutil
DRIVE = pathlib.Path("/content/drive/MyDrive/wake2vec")
RUN   = max((DRIVE/"runs").glob("t4_*"), key=lambda p: p.stat().st_mtime)
SENTRY = DRIVE/"sentry_backups"/RUN.name; SENTRY.mkdir(parents=True, exist_ok=True)

def latest_ckpt_with_weights(base):
    cands = sorted(base.glob("checkpoint-*"), key=lambda p: int(p.name.split("-")[-1]), reverse=True)
    for ck in cands:
        if (ck/"model.safetensors").exists() or (ck/"pytorch_model.bin").exists():
            return ck
    return None

ck = latest_ckpt_with_weights(RUN)
if ck is None:
    print("[SENTRY] No full checkpoint yet.")
else:
    dst = SENTRY/ck.name
    if not dst.exists():
        shutil.copytree(ck, dst)
        print(f"[SENTRY] Mirrored {ck.name} → {dst}")
    else:
        print("[SENTRY] Already mirrored:", dst)

# mirror metrics JSONs
mdst = SENTRY/"metrics"; mdst.mkdir(parents=True, exist_ok=True)
for f in (RUN/"metrics").glob("*.json"):
    shutil.copy2(f, mdst/f.name)
print("[SENTRY] Metrics mirrored →", mdst)

[SENTRY] Mirrored checkpoint-300 → /content/drive/MyDrive/wake2vec/sentry_backups/t4_1762376560/checkpoint-300
[SENTRY] Metrics mirrored → /content/drive/MyDrive/wake2vec/sentry_backups/t4_1762376560/metrics


loss and checkpoint

In [None]:
# Heartbeat
import json, pathlib
DRIVE = pathlib.Path("/content/drive/MyDrive/wake2vec")
RUN   = max((DRIVE/"runs").glob("t4_*"), key=lambda p: p.stat().st_mtime)

# loss
mlog = RUN/"metrics"/"phase1_loss_log.json"
if mlog.exists():
    logs = json.loads(mlog.read_text())
    print(f"[LOSS] step={logs[-1]['step']}  loss={float(logs[-1]['loss']):.4f}")
else:
    print("[LOSS] no metrics yet")

# checkpoints
def scan(base):
    rows=[]
    for ck in sorted(base.glob("checkpoint-*"), key=lambda p: int(p.name.split("-")[-1])):
        s=int(ck.name.split("-")[-1])
        rows.append((
            s,
            (ck/"model.safetensors").exists() or (ck/"pytorch_model.bin").exists(),
            (ck/"trainer_state.json").exists(),
            (ck/"optimizer.pt").exists(),
        ))
    return rows

print("\n[RUNS]")
for s,w,t,o in scan(RUN):
    print(f"  {s:>4}  weights={str(w):5}  state={str(t):5}  opt={str(o):5}")

SENTRY = DRIVE/"sentry_backups"/RUN.name
if SENTRY.exists():
    print("\n[SENTRY]")
    for s,w,t,o in scan(SENTRY):
        print(f"  {s:>4}  weights={str(w):5}  state={str(t):5}  opt={str(o):5}")
else:
    print("\n[SENTRY] none")

[LOSS] step=550  loss=3.2604

[RUNS]
   100  weights=True   state=True   opt=True 
   200  weights=True   state=True   opt=True 
   300  weights=True   state=True   opt=True 
   400  weights=False  state=True   opt=True 
   500  weights=False  state=True   opt=True 
   600  weights=False  state=True   opt=True 
   700  weights=False  state=True   opt=True 

[SENTRY]
   300  weights=True   state=True   opt=True 
   400  weights=False  state=True   opt=True 
   500  weights=False  state=True   opt=True 
   600  weights=False  state=False  opt=False
   700  weights=False  state=True   opt=True 


flush Drive sync

In [None]:
# Light flush to help sync small files
import os, time, pathlib
RUN = max((pathlib.Path("/content/drive/MyDrive/wake2vec")/"runs").glob("t4_*"), key=lambda p: p.stat().st_mtime)
# write_text opens, writes, and closes the file, so no handle is left open
(RUN/"_touch.sync").write_text(str(time.time()))
os.sync()
print("[SYNC] touched + sync hinted")

[SYNC] touched + sync hinted


save if trainer is in scope

In [None]:
try:
    trainer.save_model()
    if hasattr(trainer, "_save_checkpoint"):
        trainer._save_checkpoint(model=trainer.model, trial=None)
    print("[TRAINER] save requested")
except NameError:
    print("[TRAINER] No 'trainer' object in this notebook; use the mirror cell instead.")

[TRAINER] No 'trainer' object in this notebook; use the mirror cell instead.


In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=False)
print("Drive mounted.")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Drive mounted.


In [None]:
# VERBOSE mirror of latest full checkpoint (+ metrics) to sentry_backups
import pathlib, shutil, time, os

DRIVE = pathlib.Path("/content/drive/MyDrive/wake2vec")
CANDIDATES = [pathlib.Path("/content/runs"), DRIVE/"runs"]

def latest_run():
    runs = []
    for root in CANDIDATES:
        if root.exists():
            for p in root.glob("t4_*"):
                try:
                    runs.append((p.stat().st_mtime, p))
                except FileNotFoundError:
                    pass
    if not runs: return None
    runs.sort(reverse=True)
    return runs[0][1]

def latest_ckpt_with_weights(run):
    if not run: return None
    cks = sorted(run.glob("checkpoint-*"), key=lambda p: int(p.name.split("-")[-1]), reverse=True)
    for ck in cks:
        if (ck/"model.safetensors").exists() or (ck/"pytorch_model.bin").exists():
            return ck
    return None

RUN = latest_run()
print("[INFO] Active run:", RUN if RUN else "none")
CK  = latest_ckpt_with_weights(RUN)
print("[INFO] Latest full ckpt:", CK if CK else "none")

if CK is None:
    print("[SENTRY] No checkpoint with weights yet — will retry after next save.")
else:
    SENTRY = DRIVE/"sentry_backups"/RUN.name
    SENTRY.mkdir(parents=True, exist_ok=True)
    DST = SENTRY/CK.name

    src_mtime = time.ctime(CK.stat().st_mtime)
    print(f"[SENTRY] Source: {CK} (mtime {src_mtime})")
    if DST.exists():
        # check if dest is older/stale by file count or mtime
        src_files = sum(1 for _ in CK.rglob("*"))
        dst_files = sum(1 for _ in DST.rglob("*"))
        print(f"[SENTRY] Already exists: {DST} (files src={src_files} dst={dst_files})")
        if dst_files < src_files:
            print("[SENTRY] Detected partial mirror; refreshing…")
            shutil.rmtree(DST)
            shutil.copytree(CK, DST)
            print("[SENTRY] Re-mirrored:", DST)
        else:
            print("[SENTRY] Mirror up-to-date.")
    else:
        shutil.copytree(CK, DST)
        print("[SENTRY] Mirrored:", DST)

    # Mirror metrics verbosely
    msrc = RUN/"metrics"
    mdst = SENTRY/"metrics"
    mdst.mkdir(parents=True, exist_ok=True)
    copied = 0
    if msrc.exists():
        for f in msrc.glob("*.json"):
            shutil.copy2(f, mdst/f.name)
            copied += 1
    print(f"[SENTRY] Metrics mirrored → {mdst} ({copied} files)")

[INFO] Active run: /content/drive/MyDrive/wake2vec/runs/t4_1762376560
[INFO] Latest full ckpt: /content/drive/MyDrive/wake2vec/runs/t4_1762376560/checkpoint-300
[SENTRY] Source: /content/drive/MyDrive/wake2vec/runs/t4_1762376560/checkpoint-300 (mtime Wed Nov  5 22:45:46 2025)
[SENTRY] Already exists: /content/drive/MyDrive/wake2vec/sentry_backups/t4_1762376560/checkpoint-300 (files src=12 dst=12)
[SENTRY] Mirror up-to-date.
[SENTRY] Metrics mirrored → /content/drive/MyDrive/wake2vec/sentry_backups/t4_1762376560/metrics (1 files)
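The mirror cell above treats a backup as complete when the destination holds at least as many entries as the source. A file count cannot detect a file that was copied only partially, so a stricter check may be useful on a flaky Drive mount. The following is a minimal sketch that compares relative paths and file sizes; `mirror_is_complete` is a hypothetical helper, and comparing sizes rather than hashes is an assumption made to keep the check cheap.

```python
import pathlib

def mirror_is_complete(src: pathlib.Path, dst: pathlib.Path) -> bool:
    """True when every file under src exists under dst with the same size."""
    for f in src.rglob("*"):
        if not f.is_file():
            continue                       # directories are implied by files
        twin = dst / f.relative_to(src)    # same relative path in the mirror
        if not twin.is_file() or twin.stat().st_size != f.stat().st_size:
            return False
    return True
```

In the cell above, `mirror_is_complete(CK, DST)` could stand in for the `dst_files < src_files` comparison when deciding whether to re-mirror.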


In [None]:
import pathlib
BASE = pathlib.Path("/content/drive/MyDrive/wake2vec")
RUN = BASE/"runs"/"t4_1762376560"  # adjust if different
SENTRY = BASE/"sentry_backups"/"t4_1762376560"

def audit(root):
    print(f"\n[{root}]")
    for ck in sorted(root.glob("checkpoint-*"), key=lambda p: int(p.name.split("-")[-1])):
        step = int(ck.name.split("-")[-1])
        w = (ck/"model.safetensors").exists() or (ck/"pytorch_model.bin").exists()
        t = (ck/"trainer_state.json").exists()
        o = (ck/"optimizer.pt").exists()
        print(f"{step:>5}  weights={w:<5}  state={t:<5}  opt={o:<5}  → {ck.name}")

if RUN.exists():   audit(RUN)
if SENTRY.exists(): audit(SENTRY)


[/content/drive/MyDrive/wake2vec/runs/t4_1762376560]
  100  weights=1      state=1      opt=1      → checkpoint-100
  200  weights=1      state=1      opt=1      → checkpoint-200
  300  weights=1      state=1      opt=1      → checkpoint-300
  400  weights=0      state=1      opt=1      → checkpoint-400
  500  weights=0      state=1      opt=1      → checkpoint-500
  600  weights=0      state=1      opt=1      → checkpoint-600
  700  weights=0      state=1      opt=1      → checkpoint-700

[/content/drive/MyDrive/wake2vec/sentry_backups/t4_1762376560]
  300  weights=1      state=1      opt=1      → checkpoint-300
  400  weights=0      state=1      opt=1      → checkpoint-400
  500  weights=0      state=1      opt=1      → checkpoint-500
  600  weights=0      state=0      opt=0      → checkpoint-600
  700  weights=0      state=1      opt=1      → checkpoint-700


roll it from 750

In [1]:
import os
from pathlib import Path
from google.colab import drive

drive.mount('/content/drive', force_remount=False)

# Path config
DRIVE = Path("/content/drive/MyDrive/wake2vec")
RUN_ID = "t4_1762376560"
DATASETS = Path("/content/datasets")
LOCAL_RUN = Path("/content/runs") / RUN_ID
SENTRY = DRIVE / "sentry_backups" / RUN_ID
RESUME_FROM = DRIVE / "runs" / RUN_ID / "checkpoint-750-rebuilt"

# directories
LOCAL_RUN.mkdir(parents=True, exist_ok=True)
SENTRY.mkdir(parents=True, exist_ok=True)

# Training config
FINAL_TARGET = 1300
LAST_STEP = 750
TARGET = min(FINAL_TARGET, LAST_STEP + 550)

os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

print("[PATH CONFIGURATION]")
print(f"  RUN_ID: {RUN_ID}")
print(f"  RESUME_FROM: {RESUME_FROM}")
print(f"  LOCAL_RUN: {LOCAL_RUN}")
print(f"  SENTRY: {SENTRY}")
print(f"  DATASETS: {DATASETS}")
print(f"  TARGET: {LAST_STEP} → {TARGET}")

Mounted at /content/drive
[PATH CONFIGURATION]
  RUN_ID: t4_1762376560
  RESUME_FROM: /content/drive/MyDrive/wake2vec/runs/t4_1762376560/checkpoint-750-rebuilt
  LOCAL_RUN: /content/runs/t4_1762376560
  SENTRY: /content/drive/MyDrive/wake2vec/sentry_backups/t4_1762376560
  DATASETS: /content/datasets
  TARGET: 750 → 1300


In [2]:
# Dependency Pin & Compat Patches
import sys
import subprocess
import importlib

def pin(pkg, ver):
    """Pin package to specific version"""
    try:
        m = importlib.import_module(pkg)
        assert m.__version__ == ver
        print(f"[OK] {pkg} {m.__version__}")
    except Exception:
        print(f"[PIN] {pkg}=={ver}")
        subprocess.check_call([sys.executable, "-m", "pip", "install", f"{pkg}=={ver}", "-q"])
        m = importlib.import_module(pkg)
        print(f"[OK] {pkg} {m.__version__}")

pin("transformers", "4.57.1")
pin("accelerate", "1.2.1")
pin("datasets", "2.21.0")

# Apply unwrap_model compatibility patch
import accelerate
if not hasattr(accelerate.Accelerator, "_w2v_patched"):
    _orig = accelerate.Accelerator.unwrap_model
    def _shim(self, model, *args, **kw):
        kw.pop("keep_torch_compile", None)
        return _orig(self, model, *args, **kw)
    accelerate.Accelerator.unwrap_model = _shim
    accelerate.Accelerator._w2v_patched = True
    print("[PATCH] unwrap_model compatibility shim active")

[OK] transformers 4.57.1
[OK] accelerate 1.2.1
[OK] datasets 2.21.0
[PATCH] unwrap_model compatibility shim active


In [9]:
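import os
# Allocator options are "key:value" pairs; they only take effect if set
# before the first CUDA allocation in this process
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:64"
import gc, torch

# Drop large objects from earlier cells so their GPU memory can be released
for k in ["trainer","model","tok","optimizer","scheduler","dc","train_ds","valid_ds"]:
    globals().pop(k, None)
gc.collect(); torch.cuda.empty_cache()
try: torch.cuda.ipc_collect()
except Exception: pass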

In [10]:
# Model and Tok
import torch
import gc
from transformers import AutoModelForCausalLM, AutoTokenizer

# Clear GPU
torch.cuda.empty_cache()
gc.collect()

# Load tok
tok = AutoTokenizer.from_pretrained(str(RESUME_FROM), use_fast=True)
if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token or "</s>"

model = AutoModelForCausalLM.from_pretrained(
    str(RESUME_FROM),
    torch_dtype=torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True,
    attn_implementation="eager",  # request eager attention at load time
)

# Config model
model.config.use_cache = False  # the KV cache is not needed during training

# freeze all except embeddings
for p in model.parameters():
    p.requires_grad = False

emb = model.get_input_embeddings()
emb.weight.requires_grad = True

# Tie output head to input embeds
with torch.no_grad():
    model.get_output_embeddings().weight = emb.weight

model.train()

# Verification
print("[MODEL LOADED]")
print(f"  Tied weights: {'OK' if emb.weight.data_ptr() == model.get_output_embeddings().weight.data_ptr() else 'NO'}")
print(f"  Embedding trainable: {emb.weight.requires_grad}")
print(f"  Pad token ID: {tok.pad_token_id}")

import collections
devs = collections.Counter(p.device.type for p in model.parameters())
print(f"  Device distribution: {dict(devs)}")

Some weights of LlamaForCausalLM were not initialized from the model checkpoint at /content/drive/MyDrive/wake2vec/runs/t4_1762376560/checkpoint-750-rebuilt and are newly initialized: ['lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[MODEL LOADED]
  Tied weights: OK
  Embedding trainable: True
  Pad token ID: 2
  Device distribution: {'cuda': 200}


In [11]:
# Callbacks
import json
import time
import shutil
from transformers import TrainerCallback
from transformers.trainer_callback import TrainerState

class EvalEveryNSteps(TrainerCallback):
    """Trigger evaluation at fixed step intervals"""
    def __init__(self, n=200):
        self.n = n

    def on_step_end(self, args, state, control, **kw):
        s = int(state.global_step or 0)
        if s and s % self.n == 0:
            control.should_log = True
            control.should_evaluate = True

class SentryMirror(TrainerCallback):
    """Mirror checkpoints to backup directory on Drive"""
    def on_save(self, args, state, control, **kw):
        try:
            out = Path(args.output_dir)
            cks = [p for p in out.glob("checkpoint-*") if p.is_dir()]
            if not cks:
                return

            def step_of(p):
                try:
                    return int(p.name.split("-")[-1])
                except:
                    return -1

            ck = max(cks, key=step_of)
            has_weights = (ck/"model.safetensors").exists() or (ck/"pytorch_model.bin").exists()

            if not has_weights:
                return

            dst = SENTRY / ck.name
            if not dst.exists():
                shutil.copytree(ck, dst)
                print(f"[SENTRY] mirrored {ck.name}")

            # Mirror metrics
            msrc = out / "metrics"
            if msrc.exists():
                (SENTRY/"metrics").mkdir(parents=True, exist_ok=True)
                for f in msrc.glob("*.json"):
                    shutil.copy2(f, SENTRY/"metrics"/f.name)

        except Exception as e:
            print(f"[SENTRY] mirror failed: {e}")

class EmbeddingSnap(TrainerCallback):
    """Save embedding snapshots at regular intervals"""
    def __init__(self, every=50):
        self.every = every
        (DRIVE/"emb_snaps"/RUN_ID).mkdir(parents=True, exist_ok=True)

    def on_step_end(self, args, state, control, **kw):
        s = int(state.global_step or 0)
        if s and s % self.every == 0:
            try:
                E = model.get_input_embeddings().weight.detach().cpu()
                out_path = (DRIVE/"emb_snaps"/RUN_ID) / f"emb_step{s:04d}.pt"
                torch.save(E, out_path)

                # Write heartbeat metadata
                heartbeat = {
                    "run_id": RUN_ID,
                    "step": s,
                    "rows": int(E.size(0)),
                    "dim": int(E.size(1)),
                    "ts": time.time()
                }
                (out_path.parent/"heartbeat.json").write_text(
                    json.dumps(heartbeat, indent=2))

                print(f"[SNAP] embeddings saved to {out_path.name}")

            except Exception as e:
                print(f"[SNAP] failed: {e}")

class StepTimer(TrainerCallback):
    """Report average seconds per step over the last `every` steps"""
    def __init__(self, every=10):
        self.prev = None   # step at the last report
        self.t = None      # wall-clock time at the last report
        self.every = every

    def on_step_end(self, args, state, control, **kw):
        s = int(state.global_step or 0)
        now = time.time()

        if self.prev is None:
            # First call: start the measurement window
            self.prev, self.t = s, now
        elif s > self.prev and s % self.every == 0:
            # Average over the steps actually elapsed since the last report
            n = s - self.prev
            print(f"[{s:4d}] ~{(now - self.t)/n:.2f}s/step (last {n})")
            self.prev, self.t = s, now

print("[CALLBACKS CONFIGURED]")

[CALLBACKS CONFIGURED]


In [1]:
# Pin and patch FIRST
import sys, subprocess, importlib

def pin(pkg, ver):
    try:
        m = importlib.import_module(pkg)
        assert m.__version__ == ver
        print(f"[OK] {pkg} {m.__version__}")
    except Exception:
        print(f"[PIN] {pkg}=={ver}")
        subprocess.check_call([sys.executable, "-m", "pip", "install", f"{pkg}=={ver}", "-q"])
        m = importlib.import_module(pkg)
        print(f"[OK] {pkg} {m.__version__}")

pin("transformers", "4.57.1")
pin("accelerate", "1.2.1")
pin("datasets", "2.21.0")

import accelerate
if not hasattr(accelerate.Accelerator, "_w2v_patched"):
    _orig = accelerate.Accelerator.unwrap_model
    def _shim(self, model, *args, **kw):
        kw.pop("keep_torch_compile", None)
        return _orig(self, model, *args, **kw)
    accelerate.Accelerator.unwrap_model = _shim
    accelerate.Accelerator._w2v_patched = True
    print("[PATCH] active")

[OK] transformers 4.57.1
[OK] accelerate 1.2.1
[OK] datasets 2.21.0
[PATCH] active


In [2]:
# DATASETS
from pathlib import Path
import shutil

# mount is idempotent in Colab
try:
    from google.colab import drive
    drive.mount("/content/drive", force_remount=False)
except Exception:
    pass

DRIVE = Path("/content/drive/MyDrive/wake2vec")
DATASETS = Path("/content/datasets"); DATASETS.mkdir(parents=True, exist_ok=True)

def ensure_dir(src: Path, dst: Path):
    assert src.exists(), f"Missing dataset at {src}"
    if not dst.exists():
        print(f"[DATA] copying {src} → {dst}")
        shutil.copytree(src, dst)
    else:
        print(f"[DATA] already local:", dst)

# try to copy to local (fast path)
src_train = DRIVE/"datasets"/"train_ds"
src_valid = DRIVE/"datasets"/"valid_ds"
dst_train = DATASETS/"train_ds"
dst_valid = DATASETS/"valid_ds"
ensure_dir(src_train, dst_train)
ensure_dir(src_valid, dst_valid)

# load with a Drive fallback just in case
try:
    train_ds_path = str(dst_train if dst_train.exists() else src_train)
    valid_ds_path = str(dst_valid if dst_valid.exists() else src_valid)
    print("[DATA] train_ds:", train_ds_path)
    print("[DATA] valid_ds:", valid_ds_path)
except Exception as e:
    raise RuntimeError(f"Dataset path resolution failed: {e}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[DATA] already local: /content/datasets/train_ds
[DATA] already local: /content/datasets/valid_ds
[DATA] train_ds: /content/datasets/train_ds
[DATA] valid_ds: /content/datasets/valid_ds


In [None]:
# OPTIMIZED 750-1300 (Memory-Safe for T4, with grad checkpointing)
import os, gc, json, time, shutil, collections, torch
from pathlib import Path
from datasets import load_from_disk
from transformers import (AutoTokenizer, AutoModelForCausalLM,
                          DataCollatorForLanguageModeling, TrainingArguments,
                          Trainer, TrainerCallback)
from transformers.trainer_callback import TrainerState

# Mount Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# Paths
RUN_ID = "t4_1762376560"
RESUME_FROM = Path("/content/drive/MyDrive/wake2vec/runs/t4_1762376560/checkpoint-750-rebuilt")
DRIVE = Path("/content/drive/MyDrive/wake2vec")
LOCAL_RUN = Path("/content/runs")/RUN_ID
LOCAL_RUN.mkdir(parents=True, exist_ok=True)
DATASETS = Path("/content/datasets")
train_ds_path = DATASETS / "train_ds"
valid_ds_path = DATASETS / "valid_ds"

# Memory optimization
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.cuda.empty_cache()
gc.collect()

# Load tokenizer
tok = AutoTokenizer.from_pretrained(str(RESUME_FROM), use_fast=True)
if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token or "</s>"

# Load model with device_map auto
model = AutoModelForCausalLM.from_pretrained(
    str(RESUME_FROM),
    torch_dtype=torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True,
    attn_implementation="eager",  # Safer than SDPA; requested at load time
)

model.config.use_cache = False

# Freeze and tie
for p in model.parameters():
    p.requires_grad = False
emb = model.get_input_embeddings()
emb.weight.requires_grad = True
with torch.no_grad():
    model.get_output_embeddings().weight = emb.weight

model.train()
print("[TIED]", "OK" if emb.weight.data_ptr()==model.get_output_embeddings().weight.data_ptr() else "NO")

# Load data with smaller valid set
train_ds = load_from_disk(str(train_ds_path))
valid_all = load_from_disk(str(valid_ds_path))
valid_ds = valid_all.select(range(min(300, len(valid_all))))  # Reduced from 1000

# Truncating collator
base_dc = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)
MAX_LEN = 320

class TruncatingCollator:
    def __init__(self, base, max_len=MAX_LEN):
        self.base, self.max_len = base, max_len
    def __call__(self, feats):
        out = self.base(feats)
        for k,v in list(out.items()):
            if isinstance(v, torch.Tensor) and v.dim()==2:
                out[k] = v[:, :self.max_len]
        return out

dc = TruncatingCollator(base_dc, MAX_LEN)

# Callbacks
class EvalEveryNSteps(TrainerCallback):
    def __init__(self, n=200): self.n=n
    def on_step_end(self, args, state, control, **kw):
        s=int(state.global_step or 0)
        if s and s%self.n==0:
            control.should_log=True
            control.should_evaluate=True

class SentryMirror(TrainerCallback):
    def on_save(self, args, state, control, **kw):
        try:
            out = Path(args.output_dir)
            cks = [p for p in out.glob("checkpoint-*") if p.is_dir()]
            if not cks: return
            ck = max(cks, key=lambda p: int(p.name.split("-")[-1]))
            has = (ck/"model.safetensors").exists() or (ck/"pytorch_model.bin").exists()
            if not has: return
            dst = DRIVE/"sentry_backups"/RUN_ID/ck.name
            if not dst.exists():
                shutil.copytree(ck, dst)
                print(f"[SENTRY] {ck.name}")
        except Exception as e:
            print(f"[SENTRY] fail: {e}")

class EmbeddingSnap(TrainerCallback):
    def __init__(self, every=50): self.every=every
    def on_step_end(self, args, state, control, **kw):
        s=int(state.global_step or 0)
        if s and s%self.every==0:
            try:
                E = model.get_input_embeddings().weight.detach().cpu()
                out = (DRIVE/"emb_snaps"/RUN_ID)/f"emb_step{s:04d}.pt"
                out.parent.mkdir(parents=True, exist_ok=True)
                torch.save(E, out)
                print(f"[SNAP] {s}")
            except Exception as e:
                print(f"[SNAP] fail: {e}")

class StepTimer(TrainerCallback):
    def __init__(self, every=10):
        self.prev=None  # step at the last report
        self.t=None     # time at the last report
        self.every=every
    def on_step_end(self, args, state, control, **kw):
        s=int(state.global_step or 0)
        now=time.time()
        if self.prev is None:
            self.prev, self.t = s, now
        elif s>self.prev and s%self.every==0:
            # average over the steps elapsed since the last report
            print(f"[{s:4d}] {(now-self.t)/(s-self.prev):.1f}s/step")
            self.prev, self.t = s, now

# Training config
FINAL_TARGET = 1300
LAST_STEP = 750
TARGET = min(FINAL_TARGET, LAST_STEP + 550)

args = TrainingArguments(
    output_dir=str(LOCAL_RUN),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    max_steps=TARGET,
    learning_rate=5e-4,
    warmup_ratio=0.0,
    optim="adafactor",
    logging_steps=25,
    save_strategy="steps",
    save_steps=75,
    save_total_limit=2,  # Reduced
    gradient_checkpointing=True,  # ENABLE
    fp16=False,
    bf16=False,
    dataloader_num_workers=0,  # Changed
    dataloader_pin_memory=False,  # Changed
    dataloader_persistent_workers=False,  # Changed
    report_to=["none"],
    max_grad_norm=1.0,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    data_collator=dc,
    callbacks=[EvalEveryNSteps(200), SentryMirror(), EmbeddingSnap(50), StepTimer(10)],
)

# Ensure trainer state
ts = RESUME_FROM/"trainer_state.json"
if not ts.exists():
    ts.write_text(json.dumps({"global_step":LAST_STEP, "max_steps":TARGET, "log_history":[]}, indent=2))
    print("[RESUME] wrote state")

print(f"[GO] 750→{TARGET} | MAX_LEN={MAX_LEN} | GC=ON")
t0 = time.time()
trainer.train(resume_from_checkpoint=str(RESUME_FROM))
print(f"[DONE] {(time.time()-t0)/60:.1f}min")

In [6]:
# FORCE CLEAN GPU MEMORY
import gc
import torch

for name in list(globals().keys()):
    if 'model' in name.lower() or 'trainer' in name.lower():
        try:
            del globals()[name]
        except:
            pass

gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

print("[MEMORY] GPU cleaned")

[MEMORY] GPU cleaned


In [None]:
import gc
import torch

for name in list(globals().keys()):
    if 'model' in name.lower() or 'trainer' in name.lower():
        try:
            del globals()[name]
        except:
            pass

gc.collect()
torch.cuda.empty_cache()
print("[MEMORY] Cleaned")

In [None]:
# Save directly to Drive
import os, gc, torch, time, shutil
from pathlib import Path
from datasets import load_from_disk
from transformers import (AutoTokenizer, AutoModelForCausalLM,
                          DataCollatorForLanguageModeling, TrainingArguments,
                          Trainer, TrainerCallback)

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

RUN_ID = "t4_1762376560"
RESUME_FROM = Path("/content/drive/MyDrive/wake2vec/runs/t4_1762376560/checkpoint-750-rebuilt")
DRIVE = Path("/content/drive/MyDrive/wake2vec")
OUTPUT_DIR = DRIVE / "runs" / RUN_ID
DATASETS = Path("/content/datasets")

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

tok = AutoTokenizer.from_pretrained(str(RESUME_FROM), use_fast=True)
if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token or "</s>"

model = AutoModelForCausalLM.from_pretrained(
    str(RESUME_FROM),
    torch_dtype=torch.float32,
    device_map=None,
    low_cpu_mem_usage=True,
    attn_implementation="eager",  # request eager attention at load time
)
model.to("cuda")
model.config.use_cache = False

for p in model.parameters():
    p.requires_grad = False
emb = model.get_input_embeddings()
emb.weight.requires_grad = True
with torch.no_grad():
    model.get_output_embeddings().weight = emb.weight
model.train()

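# Verify the tie instead of reporting OK unconditionally
print("[TIED]", "OK" if emb.weight.data_ptr()==model.get_output_embeddings().weight.data_ptr() else "NO")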

train_ds = load_from_disk(str(DATASETS/"train_ds"))
valid_ds = load_from_disk(str(DATASETS/"valid_ds"))

base_dc = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)
MAX_LEN = 256

class TruncatingCollator:
    def __init__(self, base, max_len):
        self.base, self.max_len = base, max_len
    def __call__(self, feats):
        out = self.base(feats)
        for k,v in list(out.items()):
            if isinstance(v, torch.Tensor) and v.dim()==2:
                out[k] = v[:, :self.max_len]
        return out

dc = TruncatingCollator(base_dc, MAX_LEN)

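class StepTimer(TrainerCallback):
    """Print mean wall-clock seconds per optimizer step every `every` steps."""
    def __init__(self, every=10):
        self.every = every
        self.prev = None   # global step at the last report
        self.t = None      # wall-clock time at the last report
    def on_step_end(self, args, state, control, **kw):
        s = int(state.global_step or 0)
        now = time.time()
        if self.prev is None:
            self.prev, self.t = s, now
        elif s > self.prev and s % self.every == 0:
            # Average over the steps since the last report, not a single step.
            print(f"[{s:4d}] {(now-self.t)/(s-self.prev):.1f}s/step")
            self.prev, self.t = s, now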

TARGET = 1300

args = TrainingArguments(
    output_dir=str(OUTPUT_DIR),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    max_steps=TARGET,
    learning_rate=5e-4,
    warmup_ratio=0.0,
    optim="adafactor",
    logging_steps=50,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=5,
    gradient_checkpointing=True,
    fp16=False,
    bf16=False,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    report_to=["none"],
    max_grad_norm=1.0,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    data_collator=dc,
    callbacks=[StepTimer()],
)

print(f"[GO] 750→{TARGET} | DRIVE: {OUTPUT_DIR}")
t0 = time.time()
trainer.train(resume_from_checkpoint=str(RESUME_FROM))
print(f"[DONE] {(time.time()-t0)/60:.1f}min")
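The arguments above imply a small effective batch. As a rough sketch of the arithmetic, assuming one optimizer step per `gradient_accumulation_steps` micro-batches on a single GPU and that every sequence is truncated to at most `MAX_LEN` tokens (`steps_remaining`, `sequences_seen` and `max_tokens_seen` are names introduced here for illustration):

```python
# Values copied from the cell above; 750 is the step of the resumed checkpoint.
per_device_train_batch_size = 1
gradient_accumulation_steps = 8
MAX_LEN = 256
LAST_STEP, TARGET = 750, 1300

# Sequences consumed per optimizer step on a single GPU.
effective_batch = per_device_train_batch_size * gradient_accumulation_steps

# Optimizer steps left in this continuation phase.
steps_remaining = TARGET - LAST_STEP

# Upper bounds: sequences shorter than MAX_LEN lower the real token count.
sequences_seen = steps_remaining * effective_batch
max_tokens_seen = sequences_seen * MAX_LEN

print(f"effective batch: {effective_batch} sequences/step")
print(f"steps remaining: {steps_remaining}")
print(f"sequences seen:  {sequences_seen}")
print(f"token budget:    <= {max_tokens_seen:,}")
```

Doubling `gradient_accumulation_steps` to 16, as an earlier cell in this notebook does, doubles both bounds without changing peak memory per micro-batch.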

In [15]:
# CORRECTED - Save directly to Drive
import os, gc, torch, time, shutil
from pathlib import Path
from datasets import load_from_disk
from transformers import (AutoTokenizer, AutoModelForCausalLM,
                          DataCollatorForLanguageModeling, TrainingArguments,
                          Trainer, TrainerCallback)

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# CRITICAL FIX: Save directly to DRIVE
RUN_ID = "t4_1762376560"
RESUME_FROM = Path("/content/drive/MyDrive/wake2vec/runs/t4_1762376560/checkpoint-750-rebuilt")
DRIVE = Path("/content/drive/MyDrive/wake2vec")
OUTPUT_DIR = DRIVE / "runs" / RUN_ID
DATASETS = Path("/content/datasets")

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

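# Clean memory: drop only the previous model/trainer objects.
# Matching on the substrings 'model'/'trainer' would also delete the imported
# AutoModelForCausalLM, Trainer and TrainerCallback classes (NameError below).
for name in ("model", "trainer"):
    globals().pop(name, None)
gc.collect()
torch.cuda.empty_cache()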

print("[MEMORY] GPU cleaned")

# Load model (same as before)
tok = AutoTokenizer.from_pretrained(str(RESUME_FROM), use_fast=True)
if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token or "</s>"

model = AutoModelForCausalLM.from_pretrained(
    str(RESUME_FROM),
    torch_dtype=torch.float32,
    device_map=None,
    low_cpu_mem_usage=True,
)
model.to("cuda")
model.config.use_cache = False
model.config.attn_implementation = "eager"

for p in model.parameters():
    p.requires_grad = False
emb = model.get_input_embeddings()
emb.weight.requires_grad = True
with torch.no_grad():
    model.get_output_embeddings().weight = emb.weight
model.train()

print("[TIED]", "OK" if emb.weight.data_ptr()==model.get_output_embeddings().weight.data_ptr() else "NO")

# Data
train_ds = load_from_disk(str(DATASETS/"train_ds"))
valid_ds = load_from_disk(str(DATASETS/"valid_ds"))
print(f"[DATA] train={len(train_ds)} valid={len(valid_ds)}")

base_dc = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)
MAX_LEN = 256

class TruncatingCollator:
    def __init__(self, base, max_len):
        self.base, self.max_len = base, max_len
    def __call__(self, feats):
        out = self.base(feats)
        for k,v in list(out.items()):
            if isinstance(v, torch.Tensor) and v.dim()==2:
                out[k] = v[:, :self.max_len]
        return out

dc = TruncatingCollator(base_dc, MAX_LEN)

# Minimal callbacks
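class StepTimer(TrainerCallback):
    """Print mean wall-clock seconds per optimizer step every `every` steps."""
    def __init__(self, every=10):
        self.every = every
        self.prev = None   # global step at the last report
        self.t = None      # wall-clock time at the last report
    def on_step_end(self, args, state, control, **kw):
        s = int(state.global_step or 0)
        now = time.time()
        if self.prev is None:
            self.prev, self.t = s, now
        elif s > self.prev and s % self.every == 0:
            # Average over the steps since the last report, not a single step.
            print(f"[{s:4d}] {(now-self.t)/(s-self.prev):.1f}s/step")
            self.prev, self.t = s, now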

TARGET = 1300
LAST_STEP = 750

# CRITICAL: output_dir is now ON DRIVE
args = TrainingArguments(
    output_dir=str(OUTPUT_DIR),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    max_steps=TARGET,
    learning_rate=5e-4,
    warmup_ratio=0.0,
    optim="adafactor",
    logging_steps=50,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=5,
    gradient_checkpointing=True,
    fp16=False,
    bf16=False,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    report_to=["none"],
    max_grad_norm=1.0,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    data_collator=dc,
    callbacks=[StepTimer()],
)

print(f"[GO] {LAST_STEP}→{TARGET} | Saving to DRIVE: {OUTPUT_DIR}")
t0 = time.time()
trainer.train(resume_from_checkpoint=str(RESUME_FROM))
print(f"[DONE] {(time.time()-t0)/60:.1f}min")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[MEMORY] GPU cleaned


NameError: name 'AutoModelForCausalLM' is not defined
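The `NameError` above is what a substring-based cleanup of `globals()` produces: `'model'` is a substring of `'automodelforcausallm'`, so the imported class is deleted together with the old `model` object. A minimal sketch on a plain dict standing in for `globals()` (the dict contents are placeholders, not the real classes):

```python
# A stand-in for globals(): two live objects plus three imported names.
namespace = {
    "model": object(),
    "trainer": object(),
    "AutoModelForCausalLM": type,   # placeholder for the imported class
    "Trainer": type,                # placeholder for the imported class
    "AutoTokenizer": type,          # placeholder for the imported class
}

# Substring matching: removes the imported names as well.
by_substring = dict(namespace)
for name in list(by_substring):
    if "model" in name.lower() or "trainer" in name.lower():
        del by_substring[name]

# Exact matching: removes only the two variables that hold GPU memory.
by_exact_name = dict(namespace)
for name in ("model", "trainer"):
    by_exact_name.pop(name, None)

print(sorted(by_substring))    # only AutoTokenizer survives
print(sorted(by_exact_name))   # all three imported names survive
```

Running the cleanup before the `import` lines would also avoid the error, but exact-name matching stays safe if the cell is re-run in any order.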

In [None]:
# Resume P1 from best valid ckpt (300) → 750
import pathlib, shutil, json, time, os, torch
from datasets import load_from_disk
from transformers import (AutoTokenizer, AutoModelForCausalLM,
                          DataCollatorForLanguageModeling, TrainingArguments, Trainer, TrainerCallback)

DRIVE = pathlib.Path("/content/drive/MyDrive/wake2vec")
RUN_ID = max((DRIVE/"runs").glob("t4_*"), key=lambda p: p.stat().st_mtime).name
DRIVE_RUN = DRIVE/"runs"/RUN_ID
LOCAL_RUN = pathlib.Path("/content/runs")/RUN_ID
SENTRY    = DRIVE/"sentry_backups"/RUN_ID
LOCAL_RUN.mkdir(parents=True, exist_ok=True); SENTRY.mkdir(parents=True, exist_ok=True)

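def list_ckpts(root):
    """Return numbered checkpoint dirs under `root`, newest step first."""
    if not root.exists(): return []
    # Skip non-numeric names (e.g. checkpoint-final), which int() cannot parse.
    return sorted([p for p in root.glob("checkpoint-*")
                   if p.is_dir() and p.name.split("-")[-1].isdigit()],
                  key=lambda p: int(p.name.split("-")[-1]), reverse=True)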

def has_weights(ck):
    return (ck/"model.safetensors").exists() or (ck/"pytorch_model.bin").exists()

# Pick the newest checkpoint that has weights, from runs or sentry (checkpoint-300 was the last valid one)
cands = list_ckpts(DRIVE_RUN) + list_ckpts(SENTRY)
GOOD = next((p for p in cands if has_weights(p)), None)
assert GOOD is not None, "No valid checkpoint with weights."
last_step = int(GOOD.name.split("-")[-1])
print(f"[RESUME] {RUN_ID} from {GOOD.name} (last_step={last_step})")

# seed local output_dir from that ckpt (fast, prevents partial Drive saves)
if not (LOCAL_RUN/GOOD.name).exists():
    shutil.copytree(GOOD, LOCAL_RUN/GOOD.name)
    print("[LOCAL] seeded:", (LOCAL_RUN/GOOD.name))

# tokenizer + model
from transformers import AutoTokenizer, AutoModelForCausalLM
tok = AutoTokenizer.from_pretrained(str(GOOD), use_fast=True)
if tok.pad_token_id is None: tok.pad_token = tok.eos_token or "</s>"
model = AutoModelForCausalLM.from_pretrained(str(GOOD), torch_dtype=torch.float32, device_map="auto")
model.config.use_cache = False
with torch.no_grad():
    model.get_output_embeddings().weight = model.get_input_embeddings().weight  # tie head

# datasets
train_ds = load_from_disk(str(DRIVE/"datasets"/"train_ds"))
valid_ds = load_from_disk(str(DRIVE/"datasets"/"valid_ds"))
valid_ds = valid_ds.select(range(min(1000, len(valid_ds))))  # cap eval at 1000 rows
dc = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)

TARGET = 1300

# Callbacks
from transformers import TrainerCallback

class EvalEveryNSteps(TrainerCallback):
    def __init__(self, n=200): self.n=n
    def on_step_end(self, args, state, control, **kw):
        if state.global_step and state.global_step % self.n == 0:
            control.should_log = True
            control.should_evaluate = True

class SentryMirror(TrainerCallback):
    def on_save(self, args, state, control, **kw):
        try:
            ck = max(LOCAL_RUN.glob("checkpoint-*"), key=lambda p: int(p.name.split("-")[-1]))
            dst = SENTRY/ck.name
            if not dst.exists():
                shutil.copytree(ck, dst)
                print(f"[SENTRY] mirrored {ck.name}")
            # metrics
            mdst = SENTRY/"metrics"; mdst.mkdir(parents=True, exist_ok=True)
            msrc = LOCAL_RUN/"metrics"
            if msrc.exists():
                for f in msrc.glob("*.json"): shutil.copy2(f, mdst/f.name)
            os.sync()
        except Exception as e:
            print("[SENTRY] mirror failed:", e)

class EmbeddingSnap(TrainerCallback):
    def __init__(self, every=50):
        self.every = every
        (DRIVE/"emb_snaps"/RUN_ID).mkdir(parents=True, exist_ok=True)
    def on_step_end(self, args, state, control, **kw):
        if state.global_step and state.global_step % self.every == 0:
            try:
                emb = model.get_input_embeddings().weight.detach().cpu()
                path = DRIVE/"emb_snaps"/RUN_ID/f"emb_step{int(state.global_step):04d}.pt"
                torch.save(emb, path)
                (DRIVE/"emb_snaps"/RUN_ID/"heartbeat.json").write_text(json.dumps(
                    {"step": int(state.global_step), "rows": int(emb.size(0)), "dim": int(emb.size(1)), "ts": time.time()}, indent=2))
                print(f"[SNAP] embeddings → {path.name}")
            except Exception as e:
                print("[SNAP] failed:", e)

# freq at 50 until 700, then 100
save_steps = 50 if last_step < 700 else 100

from transformers import TrainingArguments
args = TrainingArguments(
    output_dir=str(LOCAL_RUN),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    max_steps=TARGET,
    learning_rate=5e-4,
    warmup_ratio=0.0,
    optim="adafactor",
    logging_steps=50,
    save_strategy="steps", save_steps=save_steps, save_total_limit=20,
    gradient_checkpointing=True,
    fp16=False, bf16=False,
    report_to=["none"],
    max_grad_norm=1.0,
)

from transformers import Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    data_collator=dc,
    callbacks=[EvalEveryNSteps(200), SentryMirror(), EmbeddingSnap(50)],
)

trainer.train(resume_from_checkpoint=str(LOCAL_RUN/GOOD.name))
# Finalize
trainer.save_model(str(LOCAL_RUN/"checkpoint-final")); tok.save_pretrained(str(LOCAL_RUN/"checkpoint-final"))
# Mirror final
dst = SENTRY/"checkpoint-final"
if dst.exists(): shutil.rmtree(dst)
shutil.copytree(LOCAL_RUN/"checkpoint-final", dst)
print("[SENTRY] mirrored checkpoint-final")
print("[DONE] Reached", TARGET)

[RESUME] t4_1762376560 from checkpoint-300 (last_step=300)
[LOCAL] seeded: /content/runs/t4_1762376560/checkpoint-300


`torch_dtype` is deprecated! Use `dtype` instead!
Some weights of LlamaForCausalLM were not initialized from the model checkpoint at /content/drive/MyDrive/wake2vec/runs/t4_1762376560/checkpoint-300 and are newly initialized: ['lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The model is already on multiple devices. Skipping the move to device specified in `args`.
There were missing keys in the checkpoint model loaded: ['lm_head.weight'].
	save_steps: 50 (from args) != 100 (from trainer_state.json)


Step,Training Loss,Validation Loss
350,5.4883,
400,5.3082,6.273335
450,4.7776,No Log
500,4.0891,No Log
550,3.2604,No Log
600,2.4259,7.170007
650,1.5541,No Log
700,0.7983,No Log
750,0.3192,No Log


[SNAP] embeddings → emb_step0350.pt
[SNAP] embeddings → emb_step0400.pt
[SNAP] embeddings → emb_step0450.pt
[SNAP] embeddings → emb_step0500.pt
[SNAP] embeddings → emb_step0550.pt
[SNAP] embeddings → emb_step0600.pt
[SNAP] embeddings → emb_step0650.pt
[SNAP] embeddings → emb_step0700.pt
[SNAP] embeddings → emb_step0750.pt


In [None]:
import torch
print("CUDA:", torch.cuda.is_available(), "| device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu")

CUDA: True | device: Tesla T4


In [None]:
# Pin + shim
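import sys, subprocess, importlib, importlib.metadata, os

def pin(pkg, ver):
    """Ensure `pkg` is installed at exactly `ver`; warn if a restart is needed."""
    try:
        m = importlib.import_module(pkg)
        assert m.__version__ == ver
        print(f"[OK] {pkg} {m.__version__}")
    except Exception:
        print(f"[PIN] {pkg}=={ver}")
        subprocess.check_call([sys.executable, "-m", "pip", "install", f"{pkg}=={ver}", "-q"])
        importlib.invalidate_caches()
        installed = importlib.metadata.version(pkg)   # version now on disk
        loaded = getattr(sys.modules.get(pkg), "__version__", None)
        if loaded is not None and loaded != installed:
            # An already-imported module keeps its old version until restart.
            print(f"[WARN] {pkg} {loaded} is loaded but {installed} is installed; restart the runtime")
        else:
            print(f"[OK] {pkg} {installed}")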

pin("transformers", "4.57.1")
pin("accelerate",   "1.2.1")
pin("datasets",     "2.21.0")

# unwrap_model shim (ignore keep_torch_compile kw)
import accelerate
if not hasattr(accelerate.Accelerator, "_w2v_patched"):
    _orig = accelerate.Accelerator.unwrap_model
    def _shim(self, model, *args, **kw):
        kw.pop("keep_torch_compile", None)
        return _orig(self, model, *args, **kw)
    accelerate.Accelerator.unwrap_model = _shim
    accelerate.Accelerator._w2v_patched = True
    print("[PATCH] unwrap_model shim active")

os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

[OK] transformers 4.57.1
[PIN] accelerate==1.2.1
[OK] accelerate 1.11.0
[PIN] datasets==2.21.0
[OK] datasets 4.0.0
[PATCH] unwrap_model shim active


'expandable_segments:True'
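The output above reports `accelerate 1.11.0` and `datasets 4.0.0` immediately after pinning them to 1.2.1 and 2.21.0. That pattern is what one would expect if the packages had already been imported: `pip` changes the files on disk, but an imported module stays in `sys.modules` until the runtime restarts. A small sketch for telling the two versions apart (the helper name `versions` is illustrative):

```python
import sys
import importlib.metadata

def versions(pkg):
    """Return (installed_on_disk, loaded_in_process) for `pkg`; None where unknown."""
    try:
        installed = importlib.metadata.version(pkg)
    except importlib.metadata.PackageNotFoundError:
        installed = None
    module = sys.modules.get(pkg)  # does not trigger an import
    loaded = getattr(module, "__version__", None) if module else None
    return installed, loaded

for pkg in ("transformers", "accelerate", "datasets"):
    installed, loaded = versions(pkg)
    restart_needed = None not in (installed, loaded) and installed != loaded
    print(f"{pkg}: installed={installed} loaded={loaded} restart_needed={restart_needed}")
```

If `restart_needed` is true for any package, restarting the Colab runtime and re-running the pin cell first should make the loaded and installed versions agree.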

In [None]:
# Force resume from rebuilt 750
RUN_ID = "t4_1762376560"
RESUME_FROM = f"/content/drive/MyDrive/wake2vec/runs/{RUN_ID}/checkpoint-750-rebuilt"
print("[RESUME FROM]", RESUME_FROM)

FINAL_TARGET = 1300
last_step = 750
TARGET = min(FINAL_TARGET, last_step + 300)
save_steps = 75 if last_step < 1000 else 100
print(f"[PLAN] last={last_step} → target={TARGET} | save_steps={save_steps}")


Helper cells for the P1 evaluation:

In [None]:
print("TRAIN output_dir:", trainer.args.output_dir)
print("RUN_ID:", pathlib.Path(trainer.args.output_dir).name)
print("global_step:", trainer.state.global_step)

In [None]:
import inspect
from transformers import TrainingArguments
'evaluation_strategy' in inspect.signature(TrainingArguments.__init__).parameters

In [None]:
from transformers import TrainerCallback

class EvalEveryNSteps(TrainerCallback):
    def __init__(self, n=200): self.n=n
    def on_step_end(self, args, state, control, **kw):
        if state.global_step and (state.global_step % self.n == 0):
            control.should_save = True         # ensure state gets flushed
            control.should_log = True
            control.should_evaluate = True
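With `n=200` the callback requests evaluation whenever `global_step` is a multiple of 200, regardless of the step at which training resumed. A quick check of which steps that selects (the helper `eval_steps` is illustrative and only mirrors the modulo test):

```python
def eval_steps(start, stop, n=200):
    """Global steps in (start, stop] at which a modulo-n trigger fires."""
    return [s for s in range(start + 1, stop + 1) if s % n == 0]

print(eval_steps(300, 750))    # [400, 600]
print(eval_steps(750, 1300))   # [800, 1000, 1200]
```

For the 300 to 750 segment this gives 400 and 600, the two steps that carry a validation loss in the training log above.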

In [None]:
# P1 finalise overlap@5, norm stats, and a loss plot
import json, numpy as np, pathlib, matplotlib.pyplot as plt

DRIVE_ROOT = pathlib.Path("/content/drive/MyDrive/wake2vec")
RUNS = sorted((DRIVE_ROOT/"runs").glob("t4_*"), key=lambda p: p.stat().st_mtime)
RUN_DIR = pathlib.Path("/content/runs")/RUNS[-1].name
METRICS_DIR = RUN_DIR/"metrics"; METRICS_DIR.mkdir(parents=True, exist_ok=True)
PLOTS_DIR = RUN_DIR/"plots"; PLOTS_DIR.mkdir(parents=True, exist_ok=True)

# Load current embeddings
from transformers import AutoModelForCausalLM
BASE_CKPT = RUN_DIR/"checkpoint-final"
model = AutoModelForCausalLM.from_pretrained(str(BASE_CKPT), torch_dtype="float32", device_map=None)
E_post = model.get_input_embeddings().weight.detach().cpu().numpy()

# Optional composed init + ids
E_COMP = DRIVE_ROOT/"E_comp.npy"
NEW_IDS = DRIVE_ROOT/"new_ids.npy"
has_comp = E_COMP.exists() and NEW_IDS.exists()

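def topk_overlap(a, b, k=5):
    """Mean count of shared top-k cosine neighbours per row, within `a` vs within `b`."""
    from numpy.linalg import norm
    # Rank neighbours separately in each space. Ranking both sets from one shared
    # similarity matrix makes them identical, so the overlap is trivially k.
    # (Assumed intent: compare neighbourhoods before and after training.)
    def topk(x):
        x = x / (norm(x, axis=1, keepdims=True)+1e-9)
        sims = x @ x.T
        np.fill_diagonal(sims, -np.inf)   # exclude self-matches
        return np.argsort(-sims, axis=1)[:, :k]
    top_a, top_b = topk(a), topk(b)
    inter = np.array([len(set(top_a[i]) & set(top_b[i])) for i in range(a.shape[0])])
    return inter.mean()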

report = {}

if has_comp:
    E_comp = np.load(E_COMP)
    new_ids = np.load(NEW_IDS)
    E_post_new = E_post[new_ids]
    overlap5 = topk_overlap(E_comp, E_post_new, k=5)
    # Norm drift
    from numpy.linalg import norm
    dn = (norm(E_post_new, axis=1) - norm(E_comp, axis=1)).mean()
    report.update({"overlap_at_5": float(overlap5), "mean_delta_norm": float(dn), "n_new": int(len(new_ids))})
else:
    # Fallback
    from numpy.linalg import norm
    norms = norm(E_post, axis=1)
    report.update({"post_mean_norm": float(norms.mean()), "post_std_norm": float(norms.std()), "n_vocab": int(E_post.shape[0])})

# JSON
(METRICS_DIR/"p1_summary.json").write_text(json.dumps(report, indent=2))
print("[P1 SUMMARY]", json.dumps(report, indent=2))

# Loss plot
import json, glob
state_files = [RUN_DIR/"trainer_state.json", BASE_CKPT/"trainer_state.json"]
state_files = [p for p in state_files if p.exists()]
logs = []
for sf in state_files:
    s = json.loads(sf.read_text())
    logs.extend([d for d in s.get("log_history", []) if "loss" in d])

if logs:
    steps = [d["step"] for d in logs]
    losses = [float(d["loss"]) for d in logs]
    ema = []
    alpha = 0.1
    for i,x in enumerate(losses):
        ema.append(x if i==0 else alpha*x + (1-alpha)*ema[-1])
    plt.figure(figsize=(7,4.5))
    plt.plot(steps, losses, label="loss")
    plt.plot(steps, ema, label="EMA(0.1)")
    plt.title(f"P1 Loss — {RUN_DIR.name}")
    plt.xlabel("step"); plt.ylabel("loss"); plt.grid(True, linewidth=0.3); plt.legend()
    outp = PLOTS_DIR/"p1_loss_curve.png"
    plt.savefig(outp, dpi=140, bbox_inches="tight")
    print("[PLOT]", outp)
else:
    print("[WARN] No trainer_state logs found; skip loss plot.")
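An overlap@k score is easy to sanity-check on toy data: a matrix compared with itself should score exactly k, and a matrix compared with unrelated random vectors should score much lower. The sketch below assumes the metric compares each row's k nearest cosine neighbours within one embedding matrix against the same row's neighbours within another; `neighbour_overlap` and the random data are illustrative, not part of the notebook.

```python
import numpy as np

def neighbour_overlap(a, b, k=5):
    """Mean number of shared top-k cosine neighbours per row between a and b."""
    def topk(x):
        x = x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-9)
        sims = x @ x.T
        np.fill_diagonal(sims, -np.inf)   # a row is not its own neighbour
        return np.argsort(-sims, axis=1)[:, :k]
    top_a, top_b = topk(a), topk(b)
    return float(np.mean([len(set(i) & set(j)) for i, j in zip(top_a, top_b)]))

rng = np.random.default_rng(0)
base = rng.normal(size=(200, 32))        # toy "before" embeddings
unrelated = rng.normal(size=(200, 32))   # independent toy "after" embeddings

print(neighbour_overlap(base, base))       # identical spaces: every neighbour shared
print(neighbour_overlap(base, unrelated))  # independent spaces: few shared neighbours
```

A metric that returns k for every input, including unrelated matrices, is comparing a neighbour list with itself and carries no information about drift.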

In [None]:
# Training Config and Execution
from transformers import TrainingArguments, Trainer

# Training args
args = TrainingArguments(
    output_dir=str(LOCAL_RUN),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    max_steps=TARGET,
    learning_rate=5e-4,
    warmup_ratio=0.0,
    optim="adafactor",
    logging_steps=25,
    save_strategy="steps",
    save_steps=75,
    save_total_limit=3,
    gradient_checkpointing=True,
    fp16=False,
    bf16=False,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    dataloader_persistent_workers=False,
    report_to=["none"],
    max_grad_norm=1.0,
)

# Initialize trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    data_collator=dc_trunc,
    callbacks=[
        EvalEveryNSteps(200),
        SentryMirror(),
        EmbeddingSnap(50),
        StepTimer(10)
    ],
)

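# resume state (RESUME_FROM may be a plain str, so normalise it to a Path first)
import json, pathlib
ts = pathlib.Path(RESUME_FROM) / "trainer_state.json"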
if not ts.exists():
    state_dict = {
        "global_step": LAST_STEP,
        "max_steps": TARGET,
        "log_history": []
    }
    ts.write_text(json.dumps(state_dict, indent=2))
    print(f"[RESUME] wrote trainer_state.json at step {LAST_STEP}")

# training
print(f"[GO] Resuming from step {LAST_STEP} with MAX_LEN={MAX_LEN}")
print(f"     Target: {TARGET} | Gradient checkpointing: ON")

t0 = time.time()
trainer.train(resume_from_checkpoint=str(RESUME_FROM))
elapsed = time.time() - t0

print(f"[COMPLETE] Training finished in {elapsed:.1f}s ({elapsed/60:.1f}min)")