<a href="https://colab.research.google.com/github/mahb97/Wake2vec/blob/main/Wake2vec_Llama_3_1_8b.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Wake2Vec “*F* the Embeddings” (T4 Edn)
This notebook is a colab-friendly, embedding-only finetune pipeline for large decoder LMs (Mistral-7B / Llama-2-13B / Llama-3.1-8B) using a Wake lexicon injection. It adds Joyce-specific tokens, initializes them on a sphere, and trains only the input embedding rows (optionally with a minimal LoRA r=1 on q_proj to satisfy quantized-training rules). The goal is to bend local geometry (neighbors, isotropy) while keeping the rest of the model frozen.








# p2
max_steps: 1500-2500

lr: 5e-4 → 1e-4 (cosine decay)

batch_size: 1

grad_accum: 16

Custom loss = LM_loss + λ₁·attraction + λ₂·repulsion + λ₃·morphological + λ₄·adversarial

guardrail

In [None]:
# NUCLEAR OPT
!pip uninstall -y torch torchvision torchaudio triton bitsandbytes transformers accelerate peft fastai timm

# rr
import os
os.kill(os.getpid(), 9)

Found existing installation: torch 2.8.0+cu126
Uninstalling torch-2.8.0+cu126:
  Successfully uninstalled torch-2.8.0+cu126
Found existing installation: torchvision 0.23.0+cu126
Uninstalling torchvision-0.23.0+cu126:
  Successfully uninstalled torchvision-0.23.0+cu126
Found existing installation: torchaudio 2.8.0+cu126
Uninstalling torchaudio-2.8.0+cu126:
  Successfully uninstalled torchaudio-2.8.0+cu126
Found existing installation: triton 3.4.0
Uninstalling triton-3.4.0:
  Successfully uninstalled triton-3.4.0
[0mFound existing installation: transformers 4.57.0
Uninstalling transformers-4.57.0:
  Successfully uninstalled transformers-4.57.0
Found existing installation: accelerate 1.10.1
Uninstalling accelerate-1.10.1:
  Successfully uninstalled accelerate-1.10.1
Found existing installation: peft 0.17.1
Uninstalling peft-0.17.1:
  Successfully uninstalled peft-0.17.1
Found existing installation: fastai 2.8.4
Uninstalling fastai-2.8.4:
  Successfully uninstalled fastai-2.8.4
Found exis

In [1]:
# Stop TorchAO
import os
os.environ["TRANSFORMERS_NO_TORCHAO"] = "1"

# compat versions
!pip install -q --no-cache-dir \
    torch==2.5.1 \
    triton==3.1.0 \
    bitsandbytes==0.43.3 \
    transformers==4.45.2 \
    accelerate==0.34.2 \
    peft==0.13.2

# Verify
import torch, bitsandbytes as bnb, triton
print("torch:", torch.__version__, "| cuda:", torch.version.cuda)
print("bnb:", bnb.__version__, "| triton:", triton.__version__)

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m186.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m906.4/906.4 MB[0m [31m187.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.6/209.6 MB[0m [31m75.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m137.5/137.5 MB[0m [31m82.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.9/9.9 MB[0m [31m61.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m324.4/324.4 kB[0m [31m103.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m320.7/320.7 kB[0m [31m137.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m185.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
# stop TorchAO
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

from getpass import getpass
from huggingface_hub import login

HF_TOKEN = getpass("Paste your HF token (hidden): ")
login(token=HF_TOKEN, add_to_git_credential=True)

Paste your HF token (hidden): ··········


In [4]:
# Val dataset
from torch.utils.data import Dataset
import torch

class BlockDataset(Dataset):
    """Sliding window dataset for causal LM training."""
    def __init__(self, txt_path, tokenizer, seq_len=512, stride=512):
        with open(txt_path, 'r', encoding='utf-8') as f:
            text = f.read()
        self.tokens = tokenizer(text, add_special_tokens=False)['input_ids']
        self.seq_len = seq_len
        self.stride = stride
        self.starts = list(range(0, len(self.tokens) - seq_len + 1, stride))

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        start = self.starts[idx]
        chunk = self.tokens[start:start + self.seq_len]
        return {'input_ids': torch.tensor(chunk, dtype=torch.long)}

from transformers import AutoTokenizer

print("Testing dataset creation...")
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B", use_fast=True)

ds = BlockDataset("/content/FW_TEXT.txt", tok, seq_len=512, stride=512)
print(f"  Dataset size: {len(ds)} blocks")
print(f"  First block shape: {ds[0]['input_ids'].shape}")

print("\nDataset validated")

Testing dataset creation...


Token indices sequence length is longer than the specified maximum sequence length for this model (378828 > 131072). Running this sequence through the model will result in indexing errors


  Dataset size: 739 blocks
  First block shape: torch.Size([512])

Dataset validated


In [5]:
# Helper functions
def read_lines(path):
    with open(path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f if line.strip()]

class BlockDataset(Dataset):
    """Sliding window dataset for causal LM training."""
    def __init__(self, txt_path, tokenizer, seq_len=512, stride=512):
        with open(txt_path, 'r', encoding='utf-8') as f:
            text = f.read()
        self.tokens = tokenizer(text, add_special_tokens=False)['input_ids']
        self.seq_len = seq_len
        self.stride = stride
        self.starts = list(range(0, len(self.tokens) - seq_len + 1, stride))

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        start = self.starts[idx]
        chunk = self.tokens[start:start + self.seq_len]
        return {'input_ids': torch.tensor(chunk, dtype=torch.long)}

In [6]:
# Pre-train
import os

print("GPU Check:")
print(f"  Device: {torch.cuda.get_device_name(0)}")
print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print(f"  Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")

print("\nFile Check:")
files = {
    "Wake Lexicon": "/content/wake_lexicon.txt",
    "FW Text": "/content/FW_TEXT.txt"
}
for name, path in files.items():
    exists = os.path.exists(path)
    status = "Found" if exists else "MISSING"
    print(f"  {name}: {status}")
    if exists:
        size = os.path.getsize(path) / 1024
        print(f"    Size: {size:.1f} KB")

GPU Check:
  Device: Tesla T4
  Memory: 15.83 GB
  Allocated: 0.00 GB

File Check:
  Wake Lexicon: Found
    Size: 403.0 KB
  FW Text: Found
    Size: 1358.4 KB


In [7]:
import gc
import torch

# Verify clean slate
torch.cuda.empty_cache()
gc.collect()

print(f"GPU memory allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
print(f"GPU memory cached: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")

GPU memory allocated: 0.00 GB
GPU memory cached: 0.00 GB


In [8]:
# memory footprint
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

MODEL_NAME = "meta-llama/Llama-3.2-3B"

print(f"Testing {MODEL_NAME} load on T4")

bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

print("\nLoading tokenizer...")
tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
print(f"  Vocab size: {len(tok)}")

print("\nLoading model with 4-bit quantization...")
torch.cuda.empty_cache()
initial_mem = torch.cuda.memory_allocated(0) / 1e9

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb,
    torch_dtype=torch.float16,
    device_map="auto"
)

loaded_mem = torch.cuda.memory_allocated(0) / 1e9
print(f"  Model loaded: {loaded_mem:.2f} GB")
print(f"  Delta: {loaded_mem - initial_mem:.2f} GB")

# Validate forward pass
print("\nTesting forward pass...")
test_ids = torch.tensor([[1, 2, 3, 4, 5]], device="cuda")
with torch.no_grad():
    out = model(test_ids)
    print(f"  Output shape: {out.logits.shape}")
    print(f"  Memory after forward: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")

print(f"\nPeak memory: {torch.cuda.max_memory_allocated(0) / 1e9:.2f} GB")
print("Model validated successfully")

# Cleanup
del model, tok
torch.cuda.empty_cache()
print("Memory cleared for main run")

Testing meta-llama/Llama-3.2-3B load on T4

Loading tokenizer...


tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/301 [00:00<?, ?B/s]

  Vocab size: 128256

Loading model with 4-bit quantization...


config.json:   0%|          | 0.00/844 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

  Model loaded: 2.26 GB
  Delta: 2.26 GB

Testing forward pass...


Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


  Output shape: torch.Size([1, 5, 128256])
  Memory after forward: 2.27 GB

Peak memory: 2.32 GB
Model validated successfully
Memory cleared for main run


In [None]:
import os, math, json, random, torch, shutil
from pathlib import Path
from torch.utils.data import Dataset
from transformers import (AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,
                          TrainingArguments, Trainer, TrainerCallback, set_seed)
from peft import LoraConfig, get_peft_model

SEED=42; set_seed(SEED)
MODEL_NAME = "meta-llama/Llama-3.2-1B"
WAKE_LEX_PATH = "/content/wake_lexicon.txt"
CORPUS_TXT = "/content/finnegans_wake.txt"

# CRITICAL: Save to Drive, not /content
RUN_DIR = Path("/content/drive/MyDrive/wake_llama_P1")
LOCAL_RUN = Path("/content/runs/wake_llama_P1")
SENTRY = RUN_DIR / "sentry_backups"

RUN_DIR.mkdir(parents=True, exist_ok=True)
LOCAL_RUN.mkdir(parents=True, exist_ok=True)
SENTRY.mkdir(parents=True, exist_ok=True)

SEQ_LEN=512; STRIDE=512
MAX_STEPS=1100; LOG_STEPS=20; SAVE_STEPS=200
LR=5e-5
GRAD_ACCUM=8
REPULSION_W=0.0
TARGET_NORM=None
MAX_ROW_NORM=None
REPORT_SAMPLE=1500

# 4-bit quantization
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_enable_fp32_cpu_offload=True
)

tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
if tok.pad_token is None: tok.pad_token = tok.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    max_memory={0: "13GB", "cpu": "30GB"}
)
model.config.use_cache = False
model.config.attn_implementation = "eager"
model.config.tie_word_embeddings = True
if hasattr(model, "tie_weights"): model.tie_weights()

# Frozen PEFT adapter
peft_cfg = LoraConfig(r=1, lora_alpha=1, lora_dropout=0.0,
                      target_modules=["q_proj"], bias="none", task_type="CAUSAL_LM")
model = get_peft_model(model, peft_cfg)
for n,p in model.named_parameters(): p.requires_grad=False

# Wake vocab injection
def read_lines(p):
    return [x.strip() for x in open(p, encoding="utf-8") if x.strip()] if os.path.exists(p) else []

wake = read_lines(WAKE_LEX_PATH)
missing = [t for t in wake if tok.convert_tokens_to_ids(t)==tok.unk_token_id]
num_added = tok.add_tokens(missing, special_tokens=False)

old_vocab = model.get_input_embeddings().weight.shape[0]
model.resize_token_embeddings(len(tok))
wte = model.get_input_embeddings()
if hasattr(model, "lm_head"): model.lm_head.weight = wte.weight

# Spherical init
with torch.no_grad():
    base = wte.weight[:old_vocab]; dim = base.shape[1]
    std = base.std().item(); base_radius = std * math.sqrt(dim)
    target_radius = TARGET_NORM or (1.5 * base_radius)
    if num_added>0:
        new = torch.randn((num_added, dim), device=wte.weight.device)
        new = new/(new.norm(dim=1, keepdim=True)+1e-8)*target_radius
        wte.weight.data[old_vocab:old_vocab+num_added] = new

# Only embeddings trainable
wte.weight.requires_grad=True
new_rows = torch.arange(old_vocab, old_vocab+num_added, device=wte.weight.device) if num_added>0 else None
base_rows = torch.arange(0, old_vocab, device=wte.weight.device)

def mask_grad(grad):
    if grad is None or new_rows is None: return grad
    grad[base_rows]=0; return grad
wte.weight.register_hook(mask_grad)

# Dataset
class BlockDataset(Dataset):
    def __init__(self, path, tokenizer, seq_len=512, stride=512):
        if not os.path.exists(path):
            stub = ("riverrun, past Eve and Adam's, from swerve of shore to bend of bay, "
                    "brings us by a commodius vicus of recirculation to Howth Castle and Environs. ")*2000
            text = stub
        else:
            text = open(path, "r", encoding="utf-8").read()
        ids = tokenizer(text, add_special_tokens=False)["input_ids"]
        blocks=[]
        for i in range(0, max(1, len(ids)-seq_len), stride):
            chunk = ids[i:i+seq_len]
            if len(chunk) >= seq_len//2:
                blocks.append(chunk[:seq_len])
        self.blocks = blocks
    def __len__(self): return len(self.blocks)
    def __getitem__(self, idx):
        ids = torch.tensor(self.blocks[idx], dtype=torch.long)
        return {"input_ids": ids, "labels": ids.clone(), "attention_mask": torch.ones_like(ids)}

train_ds = BlockDataset(CORPUS_TXT, tok, SEQ_LEN, STRIDE)
print(f"[Data] chunks={len(train_ds)}; tokens/step={SEQ_LEN}")

# Sentry callback
def has_weights(ck):
    return (ck/"adapter_model.safetensors").exists() or (ck/"pytorch_model.bin").exists()

class SentryMirror(TrainerCallback):
    def on_save(self, args, state, control, **kw):
        try:
            cks = sorted(LOCAL_RUN.glob("checkpoint-*"),
                        key=lambda p: int(p.name.split("-")[-1]),
                        reverse=True)
            if not cks:
                return
            ck = cks[0]
            if not has_weights(ck):
                print(f"[SENTRY] {ck.name} no weights, skip")
                return
            dst = SENTRY / ck.name
            if not dst.exists():
                print(f"[SENTRY] Mirroring {ck.name}...")
                shutil.copytree(ck, dst)
                print(f"[SENTRY] {ck.name} backed up to Drive")
            os.sync()
        except Exception as e:
            print(f"[SENTRY] ERROR: {e}")

# Custom trainer
class EmbOnlyTrainer(Trainer):
    def create_optimizer(self):
        from torch.optim import AdamW
        if not hasattr(self, "optimizer") or self.optimizer is None:
            self.optimizer = AdamW([{"params": [wte.weight], "lr": LR, "weight_decay": 0.0}],
                                   betas=(0.9, 0.999), eps=1e-8)
        return self.optimizer
    def compute_loss(self, model, inputs, return_outputs=False):
        out = model(**inputs, use_cache=False)
        loss = out.loss
        if torch.isnan(loss) or torch.isinf(loss):
            raise RuntimeError("NaN/Inf loss detected")
        return (loss, out) if return_outputs else loss

args = TrainingArguments(
    output_dir=str(LOCAL_RUN),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LR,
    max_steps=MAX_STEPS,
    warmup_steps=max(20, MAX_STEPS//20),
    lr_scheduler_type="cosine",
    weight_decay=0.0,
    fp16=False,
    bf16=True,
    logging_steps=LOG_STEPS,
    save_steps=SAVE_STEPS,
    save_total_limit=6,
    evaluation_strategy="no",
    report_to="none",
    dataloader_pin_memory=False,
    gradient_checkpointing=True,
    max_grad_norm=1.0,
)

trainer = EmbOnlyTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    data_collator=None,
    callbacks=[SentryMirror()]
)

print(f"[Run] {MODEL_NAME} | steps={MAX_STEPS} | seq_len={SEQ_LEN}")
trainer.train()

# Save final artifacts to Drive
save_dir = RUN_DIR / "final"
save_dir.mkdir(exist_ok=True)
torch.save(wte.weight.detach().cpu(), save_dir / "embed_tokens.pt")
tok.save_pretrained(str(save_dir))
print(f"[SAVED] Final artifacts to {save_dir}")

tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/301 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/843 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

max_steps is given, it will override any value given in num_train_epochs


[Data] chunks=167; tokens/step=512
[Run] meta-llama/Llama-3.2-1B | steps=1100 | seq_len=512


Step,Training Loss
20,3.1565
