<a href="https://colab.research.google.com/github/mahb97/Wake2vec/blob/main/Wake2Vec_Phase_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wake2Vec Phase 2: Full Model Fine-Tune

**Notebook:** `Wake2Vec Phase 2.ipynb`  
**Model:** TinyLlama-1.1B  
**Hardware:** Google Colab T4 GPU  

## Overview

Phase 2 fine-tunes the model on Finnegans Wake after P1's embedding-only warmup. P1 trained only the expanded vocabulary embeddings. P2 loads those embeddings, keeps the base weights frozen in 4-bit, and trains LoRA adapters on the attention (`q_proj`, `v_proj`) and MLP (`gate_proj`, `up_proj`, `down_proj`) layers. The goal is to adapt the model's behavior to the Wake's linguistic patterns.

### Prerequisites

- Phase 1 completed (embeddings trained to step 1300)
- P1 final artifacts saved in `/content/drive/MyDrive/wake2vecP1/final/`
- Finnegans Wake text file uploaded (`FW_TEXT.txt`)


### What This Notebook Does

1. **Environment Setup** - Installs compatible package versions for Nov 2025 Colab
2. **Load P1 State** - Loads tokenizer and trained embeddings from Phase 1
3. **Data Split** - Creates 90/10 train/validation split from Finnegans Wake
4. **Model Setup** - Initializes TinyLlama-1.1B with P1 embeddings and LoRA adapters
5. **Training** - Fine-tunes for 2 epochs with validation monitoring and early stopping
6. **Evaluation** - Generates loss curves and performance metrics

### Key Differences from Phase 1

| Aspect | Phase 1 | Phase 2 |
|--------|---------|---------|
| **Trainable params** | Embeddings only (~13M) | LoRA adapters only (~5M; base frozen in 4-bit) |
| **LoRA targets** | `q_proj` (frozen) | `q_proj`, `v_proj`, MLP layers |
| **Learning rate** | 5e-4 | 2e-5 |
| **Training duration** | 1300 steps | 2 epochs |
| **Validation** | None | Held-out set with early stopping |
| **Batch size** | 1 | 8 |
| **Gradient accumulation** | 16 | 2 |
| **Objective** | Warm up Wake embeddings | Adapt model behavior to Wake |

## Hyperparameters
```python
EPOCHS = 2                    # Training epochs
LR = 2e-5                     # Learning rate
WARMUP_RATIO = 0.10           # 10% warmup
BATCH_SIZE = 8                # Per-device batch size
GRAD_ACCUM = 2                # Gradient accumulation steps
WEIGHT_DECAY = 0.01           # L2 regularization
SAVE_STEPS = 200              # Checkpoint frequency
SEQ_LEN = 512                 # Sequence length (256 for safer memory)
LORA_RANK = 8                 # LoRA adapter rank
EARLY_STOP_PATIENCE = 2       # Early stopping patience
```

**Effective batch size:** 8 × 2 = 16 samples per optimizer step

## Data

### Dataset Split
- **Total blocks:** ~1,740 (512 tokens each)
- **Train blocks:** ~1,566 (90%)
- **Validation blocks:** ~174 (10%)

### Vocabulary
- **Base tokenizer:** 32,000 tokens (TinyLlama-1.1B)
- **Wake additions:** ~447-534 tokens (from P1)
- **Final vocab size:** ~32,447-32,534 tokens

### Input Format
- **Sequence length:** 512 tokens (256 for P1 compatibility)
- **Stride:** 512 tokens (non-overlapping blocks)
- **Corpus:** Finnegans Wake plain text
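The blocking rule above uses a stride equal to the sequence length and drops the trailing remainder. It can be sketched in plain Python, mirroring the loop in the train/val split cell. `make_blocks` is a hypothetical helper name used only for illustration.

```python
from typing import List, Sequence


def make_blocks(ids: Sequence[int], seq_len: int = 512) -> List[List[int]]:
    """Split a token-id sequence into non-overlapping blocks of exactly seq_len.

    Stride equals seq_len, so blocks never overlap; trailing tokens that
    cannot fill a whole block are dropped.
    """
    if seq_len <= 0:
        raise ValueError("seq_len must be positive")
    n_full = len(ids) // seq_len
    return [list(ids[i * seq_len:(i + 1) * seq_len]) for i in range(n_full)]


# Toy example: 1,100 fake token ids -> 2 full blocks of 512, 76 ids dropped
toy_ids = list(range(1100))
blocks = make_blocks(toy_ids, seq_len=512)
print(len(blocks), len(blocks[-1]), len(toy_ids) - 512 * len(blocks))
```

At ~1,740 blocks of 512 tokens, this implies roughly 890k tokens of tokenized Wake text.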

## Memory Budget (T4)

| Component | Estimated VRAM |
|-----------|----------------|
| Model (4-bit quantized) | ~1.5 GB |
| LoRA adapters (rank 8, ~5M params) | < 0.1 GB |
| Optimizer states (AdamW, LoRA params only) | < 0.1 GB |
| Activations (batch=8) | ~4-5 GB |
| **Total** | **~6-7 GB** |

**T4 capacity:** 15 GB  
**Safety margin:** ~8-9 GB (comfortable)

### Fallback Options (if OOM)
1. Reduce `SEQ_LEN` to 256 (P1 standard)
2. Reduce `BATCH_SIZE` to 4 and increase `GRAD_ACCUM` to 4
3. Reduce `LORA_RANK` to 4
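One way to sanity-check the budget is to count bytes per parameter. The sketch below makes several assumptions:

- ~1.1B base parameters stored at 4 bits (0.5 bytes each), ignoring quantization constants and the unquantized embedding/LM-head rows.
- LoRA weights and their gradients kept in fp32.
- AdamW holding two fp32 moments per trainable parameter.

Activations are left out because they depend on batch size, sequence length and gradient checkpointing. The parameter counts are assumptions, not measurements.

```python
def estimate_static_vram_gb(base_params: float, lora_params: float) -> dict:
    """Rough static memory (GB) for a 4-bit base model with fp32 LoRA + AdamW.

    Assumptions (not measured): 0.5 bytes/param for 4-bit base weights,
    4 bytes/param for fp32 LoRA weights and for their gradients,
    8 bytes/param for AdamW's two fp32 moment buffers. Activations excluded.
    """
    gb = 1e9
    return {
        "base_4bit": base_params * 0.5 / gb,
        "lora_weights": lora_params * 4 / gb,
        "lora_grads": lora_params * 4 / gb,
        "adamw_states": lora_params * 8 / gb,
    }


# ~1.1B base params; ~5.2M LoRA params (rank 8 on q/v/gate/up/down, assumed TinyLlama dims)
est = estimate_static_vram_gb(base_params=1.1e9, lora_params=5.2e6)
for name, value in est.items():
    print(f"{name:>13}: {value:.3f} GB")
```

Under these assumptions, the optimizer state for a LoRA-only run is tens of megabytes. Activations therefore dominate the budget, which is why `SEQ_LEN` and `BATCH_SIZE` are the first fallback levers.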

## Training Strategy

### LoRA Configuration
```python
from peft import LoraConfig

LoraConfig(
    r=8,                    # Low rank
    lora_alpha=16,          # scaling = alpha / r = 2
    lora_dropout=0.1,       # Regularization
    target_modules=[
        "q_proj",           # Query projections
        "v_proj",           # Value projections
        "gate_proj",        # MLP gate
        "up_proj",          # MLP up projection
        "down_proj",        # MLP down projection
    ],
    bias="none",
    task_type="CAUSAL_LM",
)
```
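The size of the trainable part of this config can be counted by hand. Each adapted linear layer with input size `d_in` and output size `d_out` gets an `A` matrix (r × d_in) and a `B` matrix (d_out × r), i.e. r·(d_in + d_out) parameters.

The dimensions below are assumptions taken from TinyLlama-1.1B's published config: hidden size 2048, MLP size 5632, 22 layers, and 32 query heads plus 4 key/value heads of size 64. Check them against `model.config`. The `model.print_trainable_parameters()` call in the model cell reports the authoritative figure.

```python
def lora_param_count(r, hidden, intermediate, n_layers, n_heads, n_kv_heads):
    """Count LoRA params for q/v/gate/up/down projections in a Llama-style block.

    Each adapted Linear(d_in -> d_out) adds r * (d_in + d_out) parameters
    (A: r x d_in, B: d_out x r). Bias adapters are not counted (bias="none").
    """
    head_dim = hidden // n_heads
    kv_dim = n_kv_heads * head_dim  # grouped-query attention shrinks v_proj's output
    shapes = {
        "q_proj": (hidden, hidden),
        "v_proj": (hidden, kv_dim),
        "gate_proj": (hidden, intermediate),
        "up_proj": (hidden, intermediate),
        "down_proj": (intermediate, hidden),
    }
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * n_layers


# Assumed TinyLlama-1.1B dimensions (verify against model.config)
n = lora_param_count(r=8, hidden=2048, intermediate=5632,
                     n_layers=22, n_heads=32, n_kv_heads=4)
print(f"{n:,} LoRA parameters")  # 5,181,440
```

That is roughly 0.5% of the ~1.1B base parameters.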

### Training Loop
- **Epochs:** 2 full passes through training data
- **Validation:** Every 200 steps
- **Checkpointing:** Save every 200 steps, keep best 3
- **Early stopping:** Stop if validation loss doesn't improve for 2 checks
- **Backups:** Automatic Drive mirroring via Sentry callback
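Validation, checkpointing and early stopping are all keyed to optimizer steps, so it is worth checking how many steps two epochs actually contain. The sketch follows how `transformers.Trainer` (4.45) derives the schedule:

- Steps per epoch are the batches per epoch, floor-divided by the accumulation steps.
- The warmup step count is the ceiling of `warmup_ratio` times the total steps.

The 1,566-block figure is the approximate train split quoted in the Data section; `trainer_step_plan` is a hypothetical helper.

```python
import math


def trainer_step_plan(n_train, batch_size, grad_accum, epochs, warmup_ratio, eval_every):
    """Approximate the optimizer-step schedule of a single-device HF Trainer run.

    Assumes drop_last=False: ceil(n / batch) batches per epoch,
    floor-divided by gradient accumulation.
    """
    batches_per_epoch = math.ceil(n_train / batch_size)
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = math.ceil(epochs * steps_per_epoch)
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    n_evals = total_steps // eval_every  # evals fire when step % eval_every == 0
    return steps_per_epoch, total_steps, warmup_steps, n_evals


plan = trainer_step_plan(n_train=1566, batch_size=8, grad_accum=2,
                         epochs=2, warmup_ratio=0.10, eval_every=200)
print(plan)  # (98, 196, 20, 0)
```

With these numbers, a 200-step interval is never reached. Under these assumptions, no mid-training validation pass or early-stopping check would run. An interval of about half an epoch (~49 steps here) gives several checks.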

## Expected Outcomes

### Performance Targets
- **Training time:** 3-5 hours on T4 (2 epochs, faster than Llama-3.2-1B)
- **Final validation loss:** < 2.5 (perplexity ~12)
- **Convergence:** Within 2 epochs with early stopping
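Perplexity is the exponential of the mean per-token cross-entropy loss (in nats). That is how the loss target above maps to a perplexity figure:

```python
import math


def perplexity(mean_ce_loss: float) -> float:
    """Perplexity from mean token-level cross-entropy measured in nats."""
    return math.exp(mean_ce_loss)


for loss in (3.0, 2.5, 2.0):
    print(f"loss {loss:.1f} -> perplexity {perplexity(loss):.1f}")
```

Applying the same function to the `eval_loss` values that Trainer logs gives validation perplexity directly.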

### Quality Indicators
- Validation loss decreases consistently
- Model generates coherent Wake-style text
- Wake tokens used appropriately in context
- P1 embedding geometry preserved
- No catastrophic forgetting of base vocabulary
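A rough quantitative proxy for "Wake tokens used in context" is the share of generated token ids that fall in the added range. The sketch assumes, as the vocabulary arithmetic in the P1 loading cell does, that Wake tokens occupy ids from `BASE_VOCAB` upward. `wake_token_rate` is a hypothetical helper, and a high rate alone says nothing about whether the usage is appropriate.

```python
from typing import Iterable


def wake_token_rate(token_ids: Iterable[int], base_vocab: int) -> float:
    """Fraction of ids that belong to tokens appended after the base vocab.

    Assumes added tokens were appended, so their ids are >= base_vocab.
    Returns 0.0 for an empty sequence.
    """
    ids = list(token_ids)
    if not ids:
        return 0.0
    return sum(1 for i in ids if i >= base_vocab) / len(ids)


# Toy ids with a 32,000-token base vocab: two of five ids are added tokens
print(wake_token_rate([15, 32001, 900, 32400, 7], base_vocab=32000))  # 0.4
```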

## File Structure
```
wake2vecP2/
├── sentry_backups/          # Drive backups
│   ├── checkpoint-200/
│   └── checkpoint-400/
├── final/                   # Final model artifacts
│   ├── adapter_model.safetensors
│   ├── adapter_config.json
│   ├── tokenizer_config.json
│   └── special_tokens_map.json
└── p2_loss_curve.png        # Training/validation plot
```

## Environment Notes (Nov 2025 Colab)

This notebook uses aggressive package management to handle Colab's Nov 2025 updates:

- **Default Colab:** torch 2.8.0, CUDA 12.9, JAX 0.7.2
- **This notebook:** torch 2.5.1+cu121, bitsandbytes 0.43.3, triton 3.1.0

See `COLAB_NOV2025_UPDATE.md` for detailed version compatibility notes.
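The notebook depends on an exact set of pins, so a small stdlib check can confirm the restart actually picked them up. This is a sketch: `check_pins` is a hypothetical helper, and the pin list just repeats versions from this section. A local build suffix such as `+cu121` is compared as part of the version string.

```python
from importlib import metadata
from typing import Callable, Dict


def check_pins(pins: Dict[str, str],
               get_version: Callable[[str], str] = metadata.version) -> Dict[str, str]:
    """Return {package: problem} for every pin that is missing or different.

    get_version defaults to importlib.metadata.version and can be swapped
    out (e.g. for testing). Versions are compared as exact strings.
    """
    problems = {}
    for name, wanted in pins.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            problems[name] = "not installed"
            continue
        if found != wanted:
            problems[name] = f"found {found}, expected {wanted}"
    return problems


PINS = {"torch": "2.5.1+cu121", "bitsandbytes": "0.43.3", "triton": "3.1.0"}
mismatches = check_pins(PINS)
print(mismatches or "all pins match")
```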

**Last updated:** 2025-11-24  
**Model:** TinyLlama-1.1B (1.1B parameters)  
**Phase 1 completion:** Step 1300, loss ~0.079  
**Next phase:** P3 (morpheme-aware regularization)

Environment Setup (Nov 2025 Colab)

In [None]:
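# NUCLEAR OPTION: uninstall Colab's preinstalled stack so the pinned versions install cleanly
!pip uninstall -y torch torchvision torchaudio triton bitsandbytes transformers accelerate peft \
    jax jaxlib flax cupy-cuda12x numba-cuda
!pip cache purge

# restart the runtime (kills the kernel); run the next cell after reconnecting
import os
os.kill(os.getpid(), 9)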

compat versions

In [None]:
# stop TorchAO
import os
os.environ["TRANSFORMERS_NO_TORCHAO"] = "1"
# expandable segments reduce CUDA memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# exact versions with explicit CUDA 12.1
!pip install --no-cache-dir \
    torch==2.5.1+cu121 torchvision==0.20.1+cu121 torchaudio==2.5.1+cu121 \
    --index-url https://download.pytorch.org/whl/cu121

!pip install -q --no-cache-dir \
    triton==3.1.0 \
    bitsandbytes==0.43.3 \
    transformers==4.45.2 \
    accelerate==0.34.2 \
    peft==0.13.2 \
    scikit-learn

# verify
import torch, bitsandbytes as bnb, triton
print("="*60)
print("PACKAGE VERSIONS")
print("="*60)
print(f"torch: {torch.__version__} | cuda: {torch.version.cuda}")
print(f"bitsandbytes: {bnb.__version__}")
print(f"triton: {triton.__version__}")
print("="*60)

verify bitsandbytes

In [None]:
# verify bitsandbytes CUDA
import torch
import bitsandbytes as bnb

print("Verification:")
print(f"  CUDA available: {torch.cuda.is_available()}")
print(f"  CUDA device: {torch.cuda.get_device_name(0)}")

# test 4-bit quantization works
try:
    from bitsandbytes.nn import Linear4bit
    test_layer = Linear4bit(10, 10, bias=False)
    test_layer.cuda()
    test_input = torch.randn(1, 10).cuda()
    with torch.no_grad():
        output = test_layer(test_input)
    print("  ✓ bitsandbytes CUDA working!")
except Exception as e:
    print(f"  ✗ bitsandbytes test failed: {e}")
    raise

drive and HF login

In [None]:
from google.colab import drive
drive.mount('/content/drive')

from getpass import getpass
from huggingface_hub import login

HF_TOKEN = getpass("Paste your HF token (hidden): ")
login(token=HF_TOKEN, add_to_git_credential=True)

check GPU

In [None]:
import torch
import gc

# clean
torch.cuda.empty_cache()
gc.collect()

print("GPU Check:")
print(f"  Device: {torch.cuda.get_device_name(0)}")
print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print(f"  Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")

get P1 final state

In [None]:
import torch
from pathlib import Path
from transformers import AutoTokenizer

# P1 paths
P1_DIR = Path("/content/drive/MyDrive/wake_llama_P1/final")

# P1 artifacts
if not P1_DIR.exists():
    raise FileNotFoundError(f"P1 final directory not found: {P1_DIR}")

print("Loading P1 final state...")

# tokenizer (with Wake vocab)
tok = AutoTokenizer.from_pretrained(str(P1_DIR), use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
print(f"✓ Tokenizer loaded: {len(tok)} tokens")

# load embeds
embed_weights = torch.load(P1_DIR / "embed_tokens.pt", map_location="cpu")
print(f"✓ Embeddings loaded: {embed_weights.shape}")

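# calc vocab expansion; BASE_VOCAB is the base tokenizer size of the P1 model
# (32,000 for TinyLlama-1.1B; it would be 128,256 for Llama-3.2-1B)
BASE_VOCAB = 32000
WAKE_TOKENS = len(tok) - BASE_VOCAB
print(f"✓ Wake tokens added in P1: {WAKE_TOKENS}")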

P2 config

In [None]:
from pathlib import Path

# P2 paths
RUN_DIR = Path("/content/drive/MyDrive/wake_llama_P2")
LOCAL_RUN = Path("/content/runs/wake_llama_P2")
SENTRY = RUN_DIR / "sentry_backups"

RUN_DIR.mkdir(parents=True, exist_ok=True)
LOCAL_RUN.mkdir(parents=True, exist_ok=True)
SENTRY.mkdir(parents=True, exist_ok=True)

# training config
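# base checkpoint: must be the same one P1 was trained from
# (assumption: the TinyLlama-1.1B chat checkpoint; swap in P1's exact model id)
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"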
FW_TEXT = "/content/FW_TEXT.txt"

# P2 hyperparameters
EPOCHS = 2
LR = 2e-5
WARMUP_RATIO = 0.10
BATCH_SIZE = 8
GRAD_ACCUM = 2
WEIGHT_DECAY = 0.01
SAVE_STEPS = 200
SEQ_LEN = 512
LORA_RANK = 8
EARLY_STOP_PATIENCE = 2

print("P2 Configuration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Learning rate: {LR}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Gradient accumulation: {GRAD_ACCUM}")
print(f"  LoRA rank: {LORA_RANK}")
print(f"  Sequence length: {SEQ_LEN}")

train/val split

In [None]:
import os
import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

class BlockDataset(Dataset):
    def __init__(self, blocks, tokenizer, seq_len=512):
        self.blocks = blocks
        self.tokenizer = tokenizer
        self.seq_len = seq_len

    def __len__(self):
        return len(self.blocks)

    def __getitem__(self, idx):
        ids = torch.tensor(self.blocks[idx], dtype=torch.long)
        return {
            "input_ids": ids,
            "labels": ids.clone(),
            "attention_mask": torch.ones_like(ids)
        }

# tokenize full text
print("Loading Finnegans Wake...")
if not os.path.exists(FW_TEXT):
    raise FileNotFoundError(f"FW text not found: {FW_TEXT}")

with open(FW_TEXT, 'r', encoding='utf-8') as f:
    text = f.read()

ids = tok(text, add_special_tokens=False)["input_ids"]
print(f"Total tokens: {len(ids)}")

# blocks
blocks = []
stride = SEQ_LEN
for i in range(0, len(ids) - SEQ_LEN + 1, stride):
    chunk = ids[i:i + SEQ_LEN]
    if len(chunk) == SEQ_LEN:
        blocks.append(chunk)

print(f"Total blocks: {len(blocks)}")

# split 90/10 train/val (train_test_split shuffles blocks before splitting)
train_blocks, val_blocks = train_test_split(
    blocks,
    test_size=0.10,
    random_state=42
)

print(f"Train blocks: {len(train_blocks)}")
print(f"Val blocks: {len(val_blocks)}")

# datasets
train_ds = BlockDataset(train_blocks, tok, SEQ_LEN)
val_ds = BlockDataset(val_blocks, tok, SEQ_LEN)

print(f"✓ Datasets ready")

get model with P1 embeds

In [None]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, set_seed
from peft import LoraConfig, get_peft_model

set_seed(42)

# 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_enable_fp32_cpu_offload=True
)

print("Loading base model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    max_memory={0: "13GB", "cpu": "30GB"},
    attn_implementation="eager",  # must be set at load time; editing config afterwards has no effect
)

model.config.use_cache = False  # KV cache is not needed during training
model.config.tie_word_embeddings = True
if hasattr(model, "tie_weights"):
    model.tie_weights()

print("✓ Base model loaded")

# resize embeddings to match P1 vocab
print(f"Resizing embeddings to {len(tok)}...")
model.resize_token_embeddings(len(tok))

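# load P1 embeds (lm_head shares the input embedding matrix)
wte = model.get_input_embeddings()
if hasattr(model, "lm_head"):
    model.lm_head.weight = wte.weight

# fail loudly if P1 was trained on a different base model or vocab
if embed_weights.shape != wte.weight.shape:
    raise ValueError(
        f"P1 embeddings {tuple(embed_weights.shape)} do not match model "
        f"embeddings {tuple(wte.weight.shape)}; check MODEL_NAME and P1_DIR"
    )

with torch.no_grad():
    wte.weight.copy_(embed_weights.to(wte.weight.device))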

print("✓ P1 embeddings loaded")

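# prepare the 4-bit model for training: freezes base weights, upcasts remaining
# bf16 params (norms, embeddings) to fp32, and makes inputs require grad so that
# gradient checkpointing can backpropagate into the LoRA layers
from peft import prepare_model_for_kbit_training
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# get LoRA adapters for P2
print("Adding LoRA adapters...")
peft_config = LoraConfig(
    r=LORA_RANK,
    lora_alpha=LORA_RANK * 2,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, peft_config)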
model.print_trainable_parameters()

trainer with val

In [None]:
import os
import shutil
from transformers import TrainingArguments, Trainer, TrainerCallback, EarlyStoppingCallback

# backup callback
def has_weights(ck):
    return (ck / "adapter_model.safetensors").exists() or (ck / "pytorch_model.bin").exists()

class SentryMirror(TrainerCallback):
    def on_save(self, args, state, control, **kw):
        try:
            cks = sorted(
                LOCAL_RUN.glob("checkpoint-*"),
                key=lambda p: int(p.name.split("-")[-1]),
                reverse=True
            )
            if not cks:
                return

            ck = cks[0]
            if not has_weights(ck):
                print(f"[SENTRY] {ck.name} no weights, skip")
                return

            dst = SENTRY / ck.name
            if not dst.exists():
                print(f"[SENTRY] Mirroring {ck.name}...")
                shutil.copytree(ck, dst)
                print(f"[SENTRY] {ck.name} backed up to Drive")
            os.sync()
        except Exception as e:
            print(f"[SENTRY] ERROR: {e}")

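# optimizer steps per epoch, computed the way Trainer does (single GPU, drop_last=False)
import math
steps_per_epoch = max(math.ceil(len(train_ds) / BATCH_SIZE) // GRAD_ACCUM, 1)

# ~1,566 train blocks give only ~98 steps/epoch (~196 total), so a fixed 200-step
# interval would never trigger eval, checkpointing or early stopping; cap it at
# half an epoch (save and eval stay aligned for load_best_model_at_end)
EVAL_STEPS = min(SAVE_STEPS, max(steps_per_epoch // 2, 1))

# training arguments
args = TrainingArguments(
    output_dir=str(LOCAL_RUN),
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LR,
    warmup_ratio=WARMUP_RATIO,
    lr_scheduler_type="cosine",
    weight_decay=WEIGHT_DECAY,
    fp16=False,
    bf16=True,
    logging_steps=50,
    save_steps=EVAL_STEPS,
    save_total_limit=3,
    eval_strategy="steps",  # replaces the deprecated `evaluation_strategy`
    eval_steps=EVAL_STEPS,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="none",
    dataloader_pin_memory=False,
    gradient_checkpointing=True,
    max_grad_norm=1.0,
)

# trainer, early stopping
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    callbacks=[
        SentryMirror(),
        EarlyStoppingCallback(early_stopping_patience=EARLY_STOP_PATIENCE)
    ]
)

print("✓ Trainer ready")
print(f"Training for {EPOCHS} epochs ({steps_per_epoch} steps/epoch) with validation every {EVAL_STEPS} steps")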

hit it

In [None]:
print("="*80)
print("WAKE2VEC PHASE 2: FULL MODEL FINE-TUNE")
print("="*80)
print(f"Train samples: {len(train_ds)}")
print(f"Val samples: {len(val_ds)}")
print(f"Total epochs: {EPOCHS}")
print(f"Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")
print("="*80)

# train
trainer.train()

print("\n" + "="*80)
print("TRAINING COMPLETE")
print("="*80)

save

In [None]:
# final model
final_dir = RUN_DIR / "final"
final_dir.mkdir(exist_ok=True)

print("Saving final P2 model...")
model.save_pretrained(str(final_dir))
tok.save_pretrained(str(final_dir))

print(f"✓ Model saved to {final_dir}")
print(f"✓ Phase 2 complete!")

eval

In [None]:
import matplotlib.pyplot as plt
import json

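# load training history (Trainer only writes trainer_state.json to the
# output_dir root when save_state() is called, so do that first)
history_file = LOCAL_RUN / "trainer_state.json"
if "trainer" in globals():
    trainer.save_state()
if history_file.exists():
    with open(history_file) as f:
        state = json.load(f)

    logs = state.get("log_history", [])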

    # extract losses
    train_loss = [(d["step"], d["loss"]) for d in logs if "loss" in d and "eval_loss" not in d]
    val_loss = [(d["step"], d["eval_loss"]) for d in logs if "eval_loss" in d]

    if train_loss and val_loss:
        # plot
        plt.figure(figsize=(12, 6))

        train_steps, train_losses = zip(*train_loss)
        val_steps, val_losses = zip(*val_loss)

        plt.plot(train_steps, train_losses, 'o-', label="Training Loss", alpha=0.7)
        plt.plot(val_steps, val_losses, 's-', label="Validation Loss", alpha=0.7)
        plt.xlabel("Steps")
        plt.ylabel("Loss")
        plt.title("Wake2Vec P2: Training & Validation Loss")
        plt.legend()
        plt.grid(True, alpha=0.3)

        plot_path = RUN_DIR / "p2_loss_curve.png"
        plt.savefig(plot_path, dpi=150, bbox_inches="tight")
        print(f"✓ Plot saved: {plot_path}")
        plt.show()

        # summary
        print("\nP2 Summary:")
        print(f"  Final train loss: {train_losses[-1]:.4f}")
        print(f"  Final val loss: {val_losses[-1]:.4f}")
        print(f"  Best val loss: {min(val_losses):.4f}")