<a href="https://colab.research.google.com/github/mahb97/Wakeifier/blob/main/Wake2vec_Llama_3_1_8b.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Wake2Vec “*F* the Embeddings” (T4 Edn)
This notebook is a Colab-friendly, embedding-only fine-tuning pipeline for large decoder LMs (Mistral-7B / Llama-2-13B / Llama-3.1-8B) that injects a Wake lexicon. It adds Joyce-specific tokens, initializes their embeddings on a sphere, and trains only the input-embedding rows (optionally with a minimal LoRA adapter, r=1 on `q_proj`, to satisfy quantized-training rules). The goal is to bend the local geometry (nearest neighbors, isotropy) while keeping the rest of the model frozen.








# p2
max_steps: 1500-2500

lr: 5e-4 → 1e-4 (cosine decay)

batch_size: 1

grad_accum: 16

Custom loss = LM_loss + λ₁·attraction + λ₂·repulsion + λ₃·morphological + λ₄·adversarial

guardrail

In [None]:
# NUCLEAR OPT: remove the preinstalled ML stack (torch family, triton,
# bitsandbytes, transformers, accelerate, peft, fastai, timm) so that the pinned
# versions installed by the following cell start from a clean environment.
!pip uninstall -y torch torchvision torchaudio triton bitsandbytes transformers accelerate peft fastai timm

# rr (presumably "restart runtime"): send signal 9 to this kernel process so any
# already-imported copies of the removed packages are dropped. Nothing placed
# after this call in the cell will run; continue with the next cell once the
# runtime is back.
import os
os.kill(os.getpid(), 9)

Found existing installation: torch 2.8.0+cu126
Uninstalling torch-2.8.0+cu126:
  Successfully uninstalled torch-2.8.0+cu126
Found existing installation: torchvision 0.23.0+cu126
Uninstalling torchvision-0.23.0+cu126:
  Successfully uninstalled torchvision-0.23.0+cu126
Found existing installation: torchaudio 2.8.0+cu126
Uninstalling torchaudio-2.8.0+cu126:
  Successfully uninstalled torchaudio-2.8.0+cu126
Found existing installation: triton 3.4.0
Uninstalling triton-3.4.0:
  Successfully uninstalled triton-3.4.0
[0mFound existing installation: transformers 4.57.0
Uninstalling transformers-4.57.0:
  Successfully uninstalled transformers-4.57.0
Found existing installation: accelerate 1.10.1
Uninstalling accelerate-1.10.1:
  Successfully uninstalled accelerate-1.10.1
Found existing installation: peft 0.17.1
Uninstalling peft-0.17.1:
  Successfully uninstalled peft-0.17.1
Found existing installation: fastai 2.8.4
Uninstalling fastai-2.8.4:
  Successfully uninstalled fastai-2.8.4
Found exis

In [1]:
# Stop TorchAO: set the opt-out flag before anything imports transformers.
# NOTE(review): assumes transformers reads TRANSFORMERS_NO_TORCHAO at import
# time — confirm for the pinned version.
import os
os.environ["TRANSFORMERS_NO_TORCHAO"] = "1"

# compat versions: one mutually pinned set of torch / triton / bitsandbytes /
# transformers / accelerate / peft. --no-cache-dir skips pip's wheel cache.
!pip install -q --no-cache-dir \
    torch==2.5.1 \
    triton==3.1.0 \
    bitsandbytes==0.43.3 \
    transformers==4.45.2 \
    accelerate==0.34.2 \
    peft==0.13.2

# Verify: import the freshly installed packages and print their versions so a
# mismatch is visible immediately. Note that `bnb` here is the bitsandbytes
# module; later cells rebind the same name to a BitsAndBytesConfig instance.
import torch, bitsandbytes as bnb, triton
print("torch:", torch.__version__, "| cuda:", torch.version.cuda)
print("bnb:", bnb.__version__, "| triton:", triton.__version__)

torch: 2.5.1+cu124 | cuda: 12.4
bnb: 0.43.3 | triton: 3.1.0


In [2]:
# stop TorchAO: repeated here so the flag is also set when this cell is the
# first one run after a runtime restart.
import os
os.environ["TRANSFORMERS_NO_TORCHAO"] = "1"

from getpass import getpass
from huggingface_hub import login

# Prompt for the Hugging Face token without echoing it, so it never appears in
# the notebook source or its saved output.
HF_TOKEN = getpass("Paste your HF token (hidden): ")
# add_to_git_credential=True additionally asks huggingface_hub to hand the
# token to the git credential helper.
login(token=HF_TOKEN, add_to_git_credential=True)

Paste your HF token (hidden): ··········


In [5]:
# Val dataset
from torch.utils.data import Dataset
import torch

class BlockDataset(Dataset):
    """Fixed-length token windows cut from a single text file.

    The whole file is tokenized once (without special tokens) and every item is
    a window of exactly ``seq_len`` token ids. Window start offsets advance by
    ``stride``; a trailing remainder shorter than ``seq_len`` is not exposed.

    Attributes:
        tokens: list of token ids for the entire file.
        seq_len: number of tokens per window.
        stride: distance between consecutive window starts.
        starts: start offset of every window.
    """

    def __init__(self, txt_path, tokenizer, seq_len=512, stride=512):
        with open(txt_path, 'r', encoding='utf-8') as handle:
            raw_text = handle.read()
        encoded = tokenizer(raw_text, add_special_tokens=False)
        self.tokens = encoded['input_ids']
        self.seq_len = seq_len
        self.stride = stride
        # A window may begin anywhere up to (and including) len - seq_len.
        stop = len(self.tokens) - seq_len + 1
        self.starts = [offset for offset in range(0, stop, stride)]

    def __len__(self):
        """Number of full windows available."""
        return len(self.starts)

    def __getitem__(self, idx):
        """Return ``{'input_ids': LongTensor}`` for window number ``idx``."""
        begin = self.starts[idx]
        end = begin + self.seq_len
        window = torch.tensor(self.tokens[begin:end], dtype=torch.long)
        return {'input_ids': window}

from transformers import AutoTokenizer

# Smoke test: build the dataset once with the real tokenizer and corpus and
# print its size, before any GPU memory is spent on the model.
print("Testing dataset creation...")
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B", use_fast=True)

# The corpus is tokenized in a single call, so the tokenizer may warn that the
# sequence exceeds the model's maximum length; the text is only ever consumed
# in 1024-token windows below.
ds = BlockDataset("/content/FW_TEXT.txt", tok, seq_len=1024, stride=1024)
print(f"  Dataset size: {len(ds)} blocks")
print(f"  First block shape: {ds[0]['input_ids'].shape}")

print("\nDataset validated")

Testing dataset creation...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Token indices sequence length is longer than the specified maximum sequence length for this model (378828 > 131072). Running this sequence through the model will result in indexing errors


  Dataset size: 369 blocks
  First block shape: torch.Size([1024])

Dataset validated


In [6]:
# Helper functions
def read_lines(path):
    """Return the non-blank lines of a UTF-8 text file, whitespace-trimmed.

    Lines that are empty after stripping are skipped; order is preserved.
    """
    kept = []
    with open(path, 'r', encoding='utf-8') as handle:
        for raw in handle:
            cleaned = raw.strip()
            if cleaned:
                kept.append(cleaned)
    return kept

class BlockDataset(Dataset):
    """Non-padded, fixed-size token windows over one text file (causal LM).

    The file is read and tokenized once without special tokens. Each item is a
    ``seq_len``-token slice starting at a multiple of ``stride``; any leftover
    tail shorter than ``seq_len`` is dropped.
    """

    def __init__(self, txt_path, tokenizer, seq_len=512, stride=512):
        with open(txt_path, 'r', encoding='utf-8') as src:
            contents = src.read()
        self.tokens = tokenizer(contents, add_special_tokens=False)['input_ids']
        self.seq_len, self.stride = seq_len, stride
        # Last admissible start is len(tokens) - seq_len, hence the +1 bound.
        upper = len(self.tokens) - seq_len + 1
        self.starts = list(range(0, upper, stride))

    def __len__(self):
        """Count of windows."""
        return len(self.starts)

    def __getitem__(self, idx):
        """Return a dict with the window's token ids as a LongTensor."""
        first = self.starts[idx]
        ids = self.tokens[first:first + self.seq_len]
        return {'input_ids': torch.tensor(ids, dtype=torch.long)}

In [8]:
# Pre-train
import os

# Preflight: report which GPU is attached and how much of its memory is
# already in use, so a too-small or busy device is caught before loading.
print("GPU Check:")
print(f"  Device: {torch.cuda.get_device_name(0)}")
print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print(f"  Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")

# Preflight: confirm both input files were uploaded. A missing file is only
# reported here, not raised as an error.
print("\nFile Check:")
files = {
    "Wake Lexicon": "/content/wake_lexicon.txt",
    "FW Text": "/content/FW_TEXT.txt"
}
for name, path in files.items():
    exists = os.path.exists(path)
    status = "Found" if exists else "MISSING"
    print(f"  {name}: {status}")
    if exists:
        size = os.path.getsize(path) / 1024  # bytes -> KB
        print(f"    Size: {size:.1f} KB")

GPU Check:
  Device: Tesla T4
  Memory: 15.83 GB
  Allocated: 0.00 GB

File Check:
  Wake Lexicon: Found
    Size: 403.0 KB
  FW Text: Found
    Size: 1358.4 KB


In [10]:
import gc
import torch

# Verify clean slate: release cached CUDA blocks and run the Python GC, then
# print what is still allocated/reserved on device 0.
torch.cuda.empty_cache()
gc.collect()

print(f"GPU memory allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
print(f"GPU memory cached: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")

# Then load model
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

MODEL_NAME = "meta-llama/Llama-3.1-8B" # Added this line to define MODEL_NAME

# 4-bit NF4 quantization with double quantization; compute is done in fp16.
# Note: this rebinds `bnb`, which the install cell used as the alias for the
# bitsandbytes module.
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
# Reuse EOS as the padding token when the tokenizer defines none.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# device_map="auto" leaves layer placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb,
    torch_dtype=torch.float16,
    device_map="auto"
)

GPU memory allocated: 0.00 GB
GPU memory cached: 0.00 GB


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [8]:
# memory footprint
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

MODEL_NAME = "meta-llama/Llama-3.1-8B"

# Dry run: load the quantized model once, measure its GPU footprint, push a
# tiny batch through it, then free everything before the real training cell.
print(f"Testing {MODEL_NAME} load on T4")

# 4-bit NF4 quantization with double quantization; compute is done in fp16.
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

print("\nLoading tokenizer...")
tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
# Reuse EOS as the padding token when the tokenizer defines none.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
print(f"  Vocab size: {len(tok)}")

print("\nLoading model with 4-bit quantization...")
torch.cuda.empty_cache()
# Baseline allocation, so the delta printed below is attributable to the load.
initial_mem = torch.cuda.memory_allocated(0) / 1e9

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb,
    torch_dtype=torch.float16,
    device_map="auto"
)

loaded_mem = torch.cuda.memory_allocated(0) / 1e9
print(f"  Model loaded: {loaded_mem:.2f} GB")
print(f"  Delta: {loaded_mem - initial_mem:.2f} GB")

# Validate forward pass on five arbitrary token ids; only the logits shape and
# the memory readings are of interest, so gradients are disabled.
print("\nTesting forward pass...")
test_ids = torch.tensor([[1, 2, 3, 4, 5]], device="cuda")
with torch.no_grad():
    out = model(test_ids)
    print(f"  Output shape: {out.logits.shape}")
    print(f"  Memory after forward: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")

print(f"\nPeak memory: {torch.cuda.max_memory_allocated(0) / 1e9:.2f} GB")
print("Model validated successfully")

# Cleanup: drop the references and release cached blocks so the training cell
# starts with an empty GPU.
del model, tok
torch.cuda.empty_cache()
print("Memory cleared for main run")

Testing meta-llama/Llama-3.1-8B load on T4

Loading tokenizer...
  Vocab size: 128256

Loading model with 4-bit quantization...


config.json:   0%|          | 0.00/826 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

  Model loaded: 5.70 GB
  Delta: 5.70 GB

Testing forward pass...


Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


  Output shape: torch.Size([1, 5, 128256])
  Memory after forward: 5.71 GB

Peak memory: 5.83 GB
Model validated successfully
Memory cleared for main run


In [None]:
import os, math, json, random, torch, torch.nn as nn
from torch.utils.data import Dataset
from transformers import (AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,
                          DataCollatorForLanguageModeling, TrainingArguments, Trainer, set_seed)
from peft import LoraConfig, get_peft_model

# CONFIG: every tunable of the run lives here.
SEED = 42                                     # passed to transformers.set_seed
MODEL_NAME = "meta-llama/Llama-3.1-8B"        # base checkpoint and tokenizer
WAKE_LEX_PATH = "/content/wake_lexicon.txt"   # one candidate token per line
CORPUS_TXT    = "/content/FW_TEXT.txt"        # training text
RUN_DIR = "/content/wake_llama_embs"          # checkpoints, tokenizer, report
SEQ_LEN = 1024                                # tokens per training block
STRIDE  = 1024                                # step between block starts
MAX_STEPS = 1100                              # optimizer steps
LOG_STEPS = 20                                # Trainer logging interval
SAVE_STEPS = 200                              # Trainer checkpoint interval
LR = 8e-4                                     # peak learning rate
GRAD_ACCUM = 8                                # micro-batches per optimizer step
REPULSION_W = 0.0                             # weight of the repulsion term in compute_loss; 0 disables it
TARGET_NORM = None                            # L2 norm for new rows; None -> 1.5 x base radius
MAX_ROW_NORM = None                           # cap on new-row L2 norm; None disables clamping
REPORT_SAMPLE = 1500                          # rows sampled for the top-k overlap metric

os.makedirs(RUN_DIR, exist_ok=True)
set_seed(SEED)
print("Torch:", torch.__version__, "| CUDA:", torch.version.cuda, "| GPU:", torch.cuda.get_device_name(0))

def read_lines(p, fb=None):
    """Read a UTF-8 text file and return its non-blank lines, stripped.

    Args:
        p: path of the file to read.
        fb: fallback returned when ``p`` does not exist; ``None`` (or any
            falsy value) yields an empty list.

    Returns:
        List of stripped, non-empty lines in file order, or the fallback.
    """
    if not os.path.exists(p):
        return fb or []
    # Context manager so the handle is closed deterministically; each line is
    # stripped once and reused for both the filter and the result.
    with open(p, encoding="utf-8") as fh:
        return [s for s in (line.strip() for line in fh) if s]

# torch-only dataset
class BlockDataset(Dataset):
    """Token blocks for causal-LM training, each returned with its labels.

    The corpus at ``path`` is tokenized once without special tokens and cut
    into windows of ``seq_len`` tokens whose starts advance by ``stride``. If
    the file does not exist, a repeated stub sentence is used instead so the
    pipeline can still be exercised.

    Args:
        path: UTF-8 text file holding the corpus.
        tok: tokenizer; called as ``tok(text, add_special_tokens=False)`` and
            expected to return a mapping with an ``"input_ids"`` list.
        seq_len: tokens per block.
        stride: distance between consecutive block starts.

    A corpus shorter than ``seq_len`` yields a single shorter block, provided
    it holds at least ``seq_len // 2`` tokens (and at least one token).
    """
    def __init__(self, path, tok, seq_len=1024, stride=1024):
        if not os.path.exists(path):
            stub = ("riverrun, past Eve and Adam's, from swerve of shore to bend of bay, "
                    "brings us by a commodius vicus of recirculation to Howth Castle and Environs. ")
            text = stub * 2000
        else:
            # Context manager: do not leave the corpus file handle open.
            with open(path, "r", encoding="utf-8") as fh:
                text = fh.read()
        ids = tok(text, add_special_tokens=False)["input_ids"]
        # Never accept an empty chunk, even when seq_len // 2 == 0.
        min_len = max(1, seq_len // 2)
        blocks = []
        # "+ 1" keeps the last full window when it ends exactly at the end of
        # the corpus (start == len(ids) - seq_len); max(1, ...) still visits
        # offset 0 for a corpus shorter than seq_len.
        for i in range(0, max(1, len(ids) - seq_len + 1), stride):
            chunk = ids[i:i + seq_len]
            if len(chunk) >= min_len:
                blocks.append(chunk)
        self.blocks = blocks

    def __len__(self):
        """Number of blocks."""
        return len(self.blocks)

    def __getitem__(self, idx):
        """Return ``input_ids`` and an independent copy of them as ``labels``."""
        ids = torch.tensor(self.blocks[idx], dtype=torch.long)
        return {"input_ids": ids, "labels": ids.clone()}

# Clear GPU memory
import gc
import torch

# Release cached CUDA blocks and collect garbage left over from earlier cells,
# then print total / allocated / reserved memory on device 0 as a baseline.
torch.cuda.empty_cache()
gc.collect()

print(f"GPU memory available: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print(f"GPU memory allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
print(f"GPU memory cached: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")

# 4-bit NF4 quantization with double quantization and fp16 compute. The extra
# flag permits fp32 CPU offload of modules that do not fit on the GPU.
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    llm_int8_enable_fp32_cpu_offload=True
)

tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
# Reuse EOS as the padding token when the tokenizer defines none.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# max_memory caps GPU 0 at 13GB and CPU at 30GB for the automatic placement.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb,
    torch_dtype=torch.float16,
    device_map="auto",
    max_memory={0: "13GB", "cpu": "30GB"}
)

# NOTE(review): this assigns a config attribute after the model has been
# built; it may not change which attention implementation the already
# constructed layers use — confirm, or pass attn_implementation to
# from_pretrained if eager attention is actually required.
model.config.attn_implementation = "eager"
# Share the output projection with the input embeddings and record that in the
# config.
# NOTE(review): if the base checkpoint ships an untied lm_head, tying replaces
# its pretrained output weights with the input embeddings — confirm this is
# intended.
if hasattr(model, "tie_weights"):
    model.tie_weights()
model.config.tie_word_embeddings = True

# Minimal PEFT adapter: a rank-1 LoRA on q_proj only, without dropout or bias
# terms. The adapter parameters are frozen again further down (every parameter
# except the embedding matrix gets requires_grad=False), so it is not trained;
# per the notebook introduction it exists to satisfy quantized-training rules.
peft_cfg = LoraConfig(
    r=1,
    lora_alpha=1,
    lora_dropout=0.0,
    target_modules=["q_proj"],
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, peft_cfg)

# Wake lexicon: add every lexicon entry that is not already a vocabulary item
# as a new (non-special) token, then grow the embedding matrix to match.
wake = read_lines(WAKE_LEX_PATH)
# Membership in the tokenizer vocabulary (base + previously added tokens) is
# tested directly instead of comparing against unk_token_id, which is only
# meaningful for tokenizers that define an unknown token. dict.fromkeys()
# removes duplicate lexicon lines while keeping first-seen order, so `missing`
# lists each new token exactly once, in the order its embedding row is created.
known_vocab = tok.get_vocab()
missing = [t for t in dict.fromkeys(wake) if t not in known_vocab]
num_added = tok.add_tokens(missing, special_tokens=False)
if num_added != len(missing):
    # The saved token list is later interpreted row-by-row; make a mismatch loud.
    print(f"[Warn] tokenizer added {num_added} tokens but {len(missing)} were requested; "
          "the order of `missing` may not match the new embedding rows")
# Row count before resizing: rows at or above this index belong to new tokens.
old_vocab = model.get_input_embeddings().weight.shape[0]
model.resize_token_embeddings(len(tok))
wte = model.get_input_embeddings()
# Re-point the output projection at the (possibly re-allocated) embedding
# matrix so input and output weights stay shared after the resize.
if hasattr(model, "lm_head"): model.lm_head.weight = wte.weight

# Spherical kick: place every new embedding row at a random direction on a
# sphere of fixed radius.
with torch.no_grad():
    base = wte.weight[:old_vocab]; dim = base.shape[1]
    # std * sqrt(dim) is the L2 norm of a dim-vector whose entries have
    # standard deviation `std` and zero mean; it serves as the reference
    # radius of the pre-existing rows.
    std = base.std().item(); base_radius = std * math.sqrt(dim)
    # TARGET_NORM overrides the default of 1.5 x the reference radius.
    target_radius = TARGET_NORM or (1.5 * base_radius)
    if num_added>0:
        new = torch.randn((num_added, dim), device=wte.weight.device)
        # Normalize each row to unit length, then scale to the target radius.
        new = new/(new.norm(dim=1, keepdim=True)+1e-8)*target_radius
        wte.weight.data[old_vocab:old_vocab+num_added] = new
print(f"[Init] new rows: {num_added} | target L2 ≈ {target_radius:.3f}")

# Trainables: only embeddings. Freeze every parameter (this includes the LoRA
# adapter), then re-enable gradients on the input embedding matrix alone.
for n,p in model.named_parameters(): p.requires_grad=False
wte.weight.requires_grad=True
# Index tensors for the added rows (None when nothing was added) and for the
# pre-existing rows; the gradient hook uses them to restrict updates.
new_rows = torch.arange(old_vocab, old_vocab+num_added, device=wte.weight.device) if num_added>0 else None
base_rows = torch.arange(0, old_vocab, device=wte.weight.device)
def mask_grad(grad):
    """Gradient hook for the embedding matrix: silence the pre-existing rows.

    Zeroes, in place, the gradient of rows ``[0, old_vocab)`` so that only the
    newly added rows receive updates. The gradient is passed through untouched
    when it is ``None`` or when no rows were added (``new_rows is None``).

    Returns:
        The same ``grad`` object that was passed in.
    """
    has_work = grad is not None and new_rows is not None
    if has_work:
        # The base rows are the contiguous prefix of the matrix.
        grad[:old_vocab] = 0
    return grad
# Run mask_grad on the embedding gradient every time it is computed.
wte.weight.register_hook(mask_grad)

def clamp_rows_(emb, max_norm):
    """Cap the L2 norm of the newly added embedding rows, in place.

    Rows ``[old_vocab, old_vocab + num_added)`` of ``emb.weight`` whose norm
    exceeds ``max_norm`` are rescaled down to exactly ``max_norm``; shorter
    rows are left as they are. Does nothing when ``max_norm`` is ``None`` or no
    rows were added.

    Args:
        emb: module exposing the embedding matrix as ``emb.weight``.
        max_norm: upper bound on each new row's L2 norm, or ``None``.
    """
    if new_rows is None or max_norm is None:
        return
    lo, hi = old_vocab, old_vocab + num_added
    added = emb.weight.data[lo:hi]
    lengths = added.norm(dim=1, keepdim=True).clamp_min(1e-8)
    # Factor is < 1 only for rows longer than max_norm; never scale a row up.
    shrink = (max_norm / lengths).clamp_max(1.0)
    emb.weight.data[lo:hi] = added * shrink

# Data: blocks are built with the extended tokenizer, so the newly added
# tokens can appear in the training ids.
train_ds = BlockDataset(CORPUS_TXT, tok, seq_len=SEQ_LEN, stride=STRIDE)
# mlm=False selects the causal-LM collator.
coll = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)
# Trade compute for memory by recomputing activations during backward.
model.gradient_checkpointing_enable()

# Geometry (pre)
@torch.no_grad()
def isotropy(M):
    """Spread of the leading principal directions of the rows of ``M``.

    Runs a low-rank PCA on the mean-centered rows and returns the ratio between
    the largest and the smallest of the leading singular values. A value near 1
    means the leading directions carry similar energy; large values mean a few
    directions dominate.

    Args:
        M: 2-D tensor of shape (rows, dim). Any floating dtype is accepted;
            the computation is carried out in float32.

    Returns:
        ``s.max() / max(s.min(), 1e-8)`` as a Python float, where ``s`` holds
        the top-``q`` singular values with
        ``q = min(128, rows - 1, dim - 1)`` (at least 1).
    """
    # The decomposition used by pca_lowrank needs at least float32.
    X = M.float()
    # pca_lowrank requires q <= min(rows, dim). Centering additionally removes
    # one degree of freedom from the rows, so asking for `rows` components
    # would always end in a numerically-zero singular value.
    q = max(1, min(128, X.shape[0] - 1, X.shape[1] - 1))
    # center=True makes pca_lowrank subtract the column means itself, which
    # avoids materializing a second centered copy of a large matrix.
    _, s, _ = torch.pca_lowrank(X, q=q, center=True)
    return float(s.max() / s.min().clamp_min(1e-8))
@torch.no_grad()
def pip_loss(A, B, chunk=8192):
    """Pairwise-inner-product (PIP) distance between two embedding matrices.

    Computes ``||A Aᵀ − B Bᵀ||_F / n²`` where ``n`` is the number of rows,
    without ever building the ``n × n`` Gram matrices. It uses the identity

        ||AAᵀ − BBᵀ||_F² = ||AᵀA||_F² + ||BᵀB||_F² − 2·||AᵀB||_F²

    so only ``dim × dim`` matrices are held, accumulated over row chunks in
    float64. This keeps memory independent of the vocabulary size and avoids
    overflow when the inputs are half precision.

    Args:
        A: tensor of shape (n, d_a).
        B: tensor of shape (n, d_b); row ``i`` corresponds to row ``i`` of A.
        chunk: number of rows processed per accumulation step.

    Returns:
        The normalized Frobenius distance as a Python float.
    """
    n = A.shape[0]
    step = max(1, int(chunk))
    gram_aa = torch.zeros(A.shape[1], A.shape[1], dtype=torch.float64, device=A.device)
    gram_bb = torch.zeros(B.shape[1], B.shape[1], dtype=torch.float64, device=B.device)
    gram_ab = torch.zeros(A.shape[1], B.shape[1], dtype=torch.float64, device=A.device)
    for start in range(0, n, step):
        a = A[start:start + step].to(torch.float64)
        b = B[start:start + step].to(torch.float64)
        gram_aa += a.T @ a
        gram_bb += b.T @ b
        gram_ab += a.T @ b
    sq = gram_aa.pow(2).sum() + gram_bb.pow(2).sum() - 2.0 * gram_ab.pow(2).sum()
    # Rounding can leave a tiny negative value when A and B are nearly equal.
    return float(sq.clamp_min(0.0).sqrt() / (n ** 2))
@torch.no_grad()
def topk_overlap(M1, M2, k=10, sample=1000):
    """Average agreement of cosine nearest-neighbour sets between two spaces.

    For a random sample of rows, the ``k`` nearest neighbours by cosine
    similarity (the row itself excluded) are looked up in ``M1`` and in ``M2``;
    the score for a row is the fraction of neighbours the two sets share.

    Args:
        M1, M2: 2-D tensors with matching row order.
        k: neighbours compared per row.
        sample: maximum number of rows probed (drawn with ``random.sample``).

    Returns:
        Mean shared fraction over the probed rows, as a Python float.
    """
    def _unit_rows(mat):
        # Row-normalize; the epsilon guards against zero-length rows.
        return mat / (mat.norm(dim=1, keepdim=True) + 1e-8)

    def _neighbours(unit, row):
        # Ask for k+1 hits because the row itself is normally the best match.
        ranked = torch.topk(unit @ unit[row], k + 1).indices.tolist()
        return set([cand for cand in ranked if cand != row][:k])

    unit_a = _unit_rows(M1)
    unit_b = _unit_rows(M2)
    n_rows = unit_a.shape[0]
    probes = random.sample(range(n_rows), min(sample, n_rows))
    total = 0.0
    for row in probes:
        shared = _neighbours(unit_a, row) & _neighbours(unit_b, row)
        total += len(shared) / k
    return float(total / len(probes))

# Snapshot the embedding matrix before training so the post-training geometry
# can be compared against it. The clone lives on the same device as the
# weights; pre_new is the slice holding only the added rows (an empty
# (0, dim) tensor when nothing was added).
with torch.no_grad():
    pre_full = wte.weight.detach().clone()
    pre_new  = pre_full[old_vocab:old_vocab+num_added].clone() if num_added>0 else torch.empty(0, pre_full.shape[1], device=wte.weight.device)

# Trainer
class EmbOnlyTrainer(Trainer):
    """Trainer variant that optimizes the input embedding matrix only.

    Relies on notebook-level names: ``wte`` (the embedding module), ``LR``,
    ``old_vocab``, ``num_added``, ``REPULSION_W``, ``MAX_ROW_NORM``,
    ``clamp_rows_`` and the global ``model``.
    """
    def create_optimizer(self):
        """Build (once) an AdamW optimizer over ``wte.weight`` alone.

        No weight decay is applied. An already existing optimizer is reused.
        """
        from torch.optim import AdamW
        if not hasattr(self, "optimizer") or self.optimizer is None:
            # NOTE(review): the model is loaded with torch_dtype=float16 and
            # the run disables fp16/bf16 mixed precision; if the embedding
            # weight is fp16, eps=1e-8 is below fp16 resolution — confirm the
            # update does not produce NaN/inf for rows with zero gradient.
            self.optimizer = AdamW([{"params": [wte.weight], "lr": LR, "weight_decay": 0.0}],
                                   betas=(0.9, 0.999), eps=1e-8)
        return self.optimizer
    def compute_loss(self, model, inputs, return_outputs=False):
        """Language-model loss plus an optional repulsion penalty.

        When more than one token was added and ``REPULSION_W > 0``, the new
        rows are mean-centered and unit-normalized, and the mean squared
        deviation of their similarity matrix from the identity is added,
        weighted by ``REPULSION_W``.

        Returns:
            ``(loss, outputs)`` when ``return_outputs`` is true, else ``loss``.
        """
        out = model(**inputs); loss = out.loss
        if num_added and num_added>1 and REPULSION_W>0:
            E = model.get_input_embeddings().weight[old_vocab:old_vocab+num_added]
            E = E - E.mean(0, keepdim=True); E = E/(E.norm(dim=1, keepdim=True)+1e-8)
            # Subtracting the identity removes each row's self-similarity.
            sims = (E @ E.t()); repul = (sims - torch.eye(E.shape[0], device=E.device)).pow(2).mean()
            loss = loss + REPULSION_W*repul
        return (loss, out) if return_outputs else loss
    def training_step(self, *args, **kwargs):
        """Run the stock training step, then clamp the new rows' norms.

        The clamp is applied to the notebook-global ``model`` and is a no-op
        while ``MAX_ROW_NORM`` is ``None``.
        """
        out = super().training_step(*args, **kwargs)
        clamp_rows_(model.get_input_embeddings(), MAX_ROW_NORM)
        return out

# Training setup: batch size 1 with GRAD_ACCUM micro-batches per optimizer
# step, cosine schedule with a warmup of at least 20 steps (5% of MAX_STEPS
# when that is larger), no weight decay, no fp16/bf16 mixed precision, no
# evaluation, no external reporting, and at most three checkpoints kept.
args = TrainingArguments(
    output_dir=RUN_DIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LR,
    max_steps=MAX_STEPS,
    warmup_steps=max(20, MAX_STEPS//20),
    lr_scheduler_type="cosine",
    weight_decay=0.0,
    fp16=False,
    bf16=False,
    logging_steps=LOG_STEPS,
    save_steps=SAVE_STEPS,
    save_total_limit=3,
    eval_strategy="no",
    report_to="none",
    dataloader_pin_memory=False,
    gradient_checkpointing=True,
)

# Build the embedding-only trainer, print a one-line run summary and train.
trainer = EmbOnlyTrainer(model=model, args=args, data_collator=coll, train_dataset=train_ds)
print(f"[Run] emb-only | base={MODEL_NAME} | steps={MAX_STEPS} | seq_len={SEQ_LEN} | accum={GRAD_ACCUM} | LR={LR}")
trainer.train()

# Save: only the artifacts this run changed — the full embedding matrix (moved
# to CPU), the list of added tokens with the original/added row counts, and
# the extended tokenizer. The base model weights are not written.
save_dir = os.path.join(RUN_DIR, "embedding_only"); os.makedirs(save_dir, exist_ok=True)
torch.save(wte.weight.detach().cpu(), os.path.join(save_dir, "embed_tokens.pt"))
with open(os.path.join(save_dir, "added_tokens.json"), "w") as f:
    json.dump({"added_tokens": missing, "old_vocab": old_vocab, "num_added": num_added}, f, indent=2)
tok.save_pretrained(RUN_DIR)

# Geometry (post): snapshot the trained embedding matrix, mirroring the
# pre-training snapshot, so both can be fed to the comparison metrics.
with torch.no_grad():
    post_full = wte.weight.detach().clone()
    post_new  = post_full[old_vocab:old_vocab+num_added].clone() if num_added>0 else torch.empty(0, pre_full.shape[1], device=wte.weight.device)

# Geometry report: compare the embedding space before and after training.
#   pip_loss_*        change in pairwise inner products (full matrix / new rows)
#   topk_overlap_all  agreement of cosine nearest neighbours on sampled rows
#   isotropy_*        singular-value spread before, after, and for new rows
# Metrics restricted to the new rows are None unless at least two were added.
report = {
  "model": MODEL_NAME, "added_tokens": int(num_added), "old_vocab": int(old_vocab),
  "pip_loss_full": pip_loss(pre_full, post_full),
  "topk_overlap_all": topk_overlap(pre_full, post_full, k=10, sample=min(REPORT_SAMPLE, pre_full.shape[0]-1)),
  "isotropy_pre": isotropy(pre_full), "isotropy_post": isotropy(post_full),
  "pip_loss_new_rows": (pip_loss(pre_new, post_new) if num_added>1 else None),
  "isotropy_new_rows": (isotropy(post_new) if num_added>1 else None),
}
# Context manager so the report is flushed and the handle closed right away.
with open(os.path.join(RUN_DIR, "geometry_report.json"), "w") as f:
    json.dump(report, f, indent=2)
print("\n GEOM REPORT")
for k,v in report.items(): print(f"{k}: {v}")