[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mmcmanus1/rlhf-canary/blob/main/notebooks/09_quantization_and_memory.ipynb)

# Quantization & Memory Optimization: Running Canaries on Limited Hardware

Use 4-bit and 8-bit quantization to run canaries on memory-constrained GPUs. Learn baseline management workflows and compare quantized vs full-precision training.

**What you'll learn:**
1. Enabling 4-bit quantization with bitsandbytes
2. Enabling 8-bit quantization
3. Memory savings: quantized vs full precision
4. Performance impact of quantization
5. Baseline management workflows
6. Comparing baselines across configurations

**Requirements:** GPU runtime (Runtime > Change runtime type > T4 GPU)

**Runtime:** ~20-25 minutes (multiple runs for comparison)

## 1. Setup

In [None]:
import os
import re
import sys

print("Starting Environment Setup...")

# --- 1. Clone or update the repo ---
if not os.path.exists("/content/rlhf-canary"):
    !git clone https://github.com/mmcmanus1/rlhf-canary.git /content/rlhf-canary
else:
    !cd /content/rlhf-canary && git pull --ff-only

%cd /content/rlhf-canary

# --- 2. Force-Install the "Safe Harbor" Stack ---
!pip install "trl==0.11.4" "transformers==4.44.2" "peft==0.12.0" "accelerate==0.34.2" "tokenizers==0.19.1" --force-reinstall --no-deps --quiet
!pip install -q datasets pydantic click PyYAML bitsandbytes
print("Libraries installed (TRL 0.11.4 / Transformers 4.44.2)")

# --- 3. Patch pyproject.toml ---
project_file = "/content/rlhf-canary/pyproject.toml"
if os.path.exists(project_file):
    with open(project_file, "r") as f:
        content = f.read()
    
    if "trl==0.11.4" not in content:
        content = re.sub(r'trl[<>=!~]+[\d\.]+', 'trl==0.11.4', content)
        with open(project_file, "w") as f:
            f.write(content)
        print("Config file patched to lock TRL 0.11.4")

# --- 4. Patch Source Code ---
runner_file = "/content/rlhf-canary/canary/runner/local.py"
if os.path.exists(runner_file):
    with open(runner_file, "r") as f:
        code = f.read()
    
    if "processing_class=" in code:
        code = code.replace("processing_class=", "tokenizer=")
        with open(runner_file, "w") as f:
            f.write(code)
        print("Code patched: Reverted 'processing_class' to 'tokenizer'")
    else:
        print("Code is already compatible.")

# --- 5. Install the package ---
!pip install -e . --quiet

print("Environment Ready!")

In [None]:
# Verify GPU is available
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_memory:.1f} GB")
    
    # Check if bitsandbytes is working
    try:
        import bitsandbytes
        print(f"bitsandbytes: {bitsandbytes.__version__}")
    except ImportError:
        print("WARNING: bitsandbytes not available")

## 2. Understanding Quantization Options

RLHF Canary supports two quantization modes via bitsandbytes:

| Option | Approx. Weight Memory Savings (vs FP16) | Speed Impact | Use Case |
|--------|-----------------------------------------|--------------|----------|
| `load_in_4bit` | ~75% | Slight slowdown | Very limited VRAM (<8GB) |
| `load_in_8bit` | ~50% | Minimal | Moderate VRAM (8-12GB) |
| None (FP16) | Baseline | Fastest | Ample VRAM (16GB+) |
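The savings column describes weight storage: FP16 uses 2 bytes per parameter, 8-bit uses 1 byte, and 4-bit uses half a byte. A back-of-the-envelope sketch of that arithmetic, assuming a rounded count of 70 million parameters for `EleutherAI/pythia-70m` and ignoring the small overhead of quantization constants (the helper names are illustrative):

```python
# Approximate bytes needed to store one weight in each mode.
BYTES_PER_PARAM = {"fp16": 2.0, "8bit": 1.0, "4bit": 0.5}


def weight_memory_mb(num_params: int, mode: str) -> float:
    """Estimate weight storage in MB (1 MB = 1e6 bytes) for a given mode."""
    return num_params * BYTES_PER_PARAM[mode] / 1e6


def savings_vs_fp16(mode: str) -> float:
    """Fractional reduction in weight memory relative to FP16."""
    return 1 - BYTES_PER_PARAM[mode] / BYTES_PER_PARAM["fp16"]


NUM_PARAMS = 70_000_000  # assumption: rounded parameter count for pythia-70m

for mode in ("fp16", "8bit", "4bit"):
    print(f"{mode:>5}: {weight_memory_mb(NUM_PARAMS, mode):6.1f} MB "
          f"({savings_vs_fp16(mode):.0%} saved)")
```

Peak memory measured later in this notebook also includes activations, LoRA optimizer state, and the CUDA context, none of which shrink with weight quantization. In addition, transformers' bitsandbytes integration quantizes linear layers, so embedding tables stay in higher precision. For a model this small, expect measured savings well below these percentages.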

### How It Works

When `load_in_4bit: true` is set, RLHF Canary uses:

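```python
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # Dtype used for matmuls on dequantized weights
    bnb_4bit_use_double_quant=True,        # Also quantizes the quantization constants
    bnb_4bit_quant_type="nf4",             # 4-bit NormalFloat (NF4)
)
```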

For 8-bit:
```python
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(load_in_8bit=True)
```
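A minimal sketch of how the two YAML flags could map onto these settings. The helper `quantization_kwargs` is hypothetical and not part of RLHF Canary; it returns plain keyword arguments so it runs without a GPU or transformers installed. Like `BitsAndBytesConfig` itself, it rejects enabling both modes at once.

```python
from typing import Optional


def quantization_kwargs(load_in_4bit: bool, load_in_8bit: bool) -> Optional[dict]:
    """Hypothetical helper: translate canary YAML flags into BitsAndBytesConfig kwargs.

    Returns None when neither flag is set (full-precision load).
    """
    if load_in_4bit and load_in_8bit:
        raise ValueError("load_in_4bit and load_in_8bit are mutually exclusive")
    if load_in_4bit:
        return {
            "load_in_4bit": True,
            # BitsAndBytesConfig also accepts the compute dtype as a string.
            "bnb_4bit_compute_dtype": "float16",
            "bnb_4bit_use_double_quant": True,
            "bnb_4bit_quant_type": "nf4",
        }
    if load_in_8bit:
        return {"load_in_8bit": True}
    return None


print(quantization_kwargs(load_in_4bit=False, load_in_8bit=True))
```

With transformers installed, `BitsAndBytesConfig(**quantization_kwargs(True, False))` builds the same 4-bit configuration as the snippet above; a `None` result means the model is loaded without a quantization config.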

## 3. Creating Quantized Configurations

Let's create three configs: full precision (FP16), 8-bit, and 4-bit.

In [None]:
# Full precision config (baseline)
fp16_config = """
name: dpo_fp16
description: DPO canary with full precision (FP16)

model_name: EleutherAI/pythia-70m
use_peft: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05

# No quantization - full FP16
load_in_4bit: false
load_in_8bit: false

training_type: dpo
max_steps: 50
batch_size: 2
gradient_accumulation_steps: 4
learning_rate: 5.0e-5
max_length: 256
warmup_steps: 5

beta: 0.1
max_prompt_length: 64

dataset_name: Anthropic/hh-rlhf
dataset_split: train
dataset_size: 512
seed: 42

output_dir: ./canary_output
metrics_warmup_steps: 5
"""

!mkdir -p configs/quantization
with open('configs/quantization/dpo_fp16.yaml', 'w') as f:
    f.write(fp16_config)

print("Created configs/quantization/dpo_fp16.yaml")

In [None]:
# 8-bit quantization config
int8_config = """
name: dpo_8bit
description: DPO canary with 8-bit quantization

model_name: EleutherAI/pythia-70m
use_peft: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05

# 8-bit quantization
load_in_4bit: false
load_in_8bit: true

training_type: dpo
max_steps: 50
batch_size: 2
gradient_accumulation_steps: 4
learning_rate: 5.0e-5
max_length: 256
warmup_steps: 5

beta: 0.1
max_prompt_length: 64

dataset_name: Anthropic/hh-rlhf
dataset_split: train
dataset_size: 512
seed: 42

output_dir: ./canary_output
metrics_warmup_steps: 5
"""

with open('configs/quantization/dpo_8bit.yaml', 'w') as f:
    f.write(int8_config)

print("Created configs/quantization/dpo_8bit.yaml")

In [None]:
# 4-bit quantization config
int4_config = """
name: dpo_4bit
description: DPO canary with 4-bit quantization (NF4)

model_name: EleutherAI/pythia-70m
use_peft: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05

# 4-bit quantization (NF4 with double quantization)
load_in_4bit: true
load_in_8bit: false

training_type: dpo
max_steps: 50
batch_size: 2
gradient_accumulation_steps: 4
learning_rate: 5.0e-5
max_length: 256
warmup_steps: 5

beta: 0.1
max_prompt_length: 64

dataset_name: Anthropic/hh-rlhf
dataset_split: train
dataset_size: 512
seed: 42

output_dir: ./canary_output
metrics_warmup_steps: 5
"""

with open('configs/quantization/dpo_4bit.yaml', 'w') as f:
    f.write(int4_config)

print("Created configs/quantization/dpo_4bit.yaml")
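The three configs above are identical except for their name, description, and quantization flags, and that is what makes their metrics comparable. A sketch of building them from one shared base and confirming which keys differ; `BASE`, `VARIANTS`, `build_config`, and `config_diff` are illustrative names, and the `description` field is omitted for brevity:

```python
import copy

# Shared settings copied from the three YAML configs above.
BASE = {
    "model_name": "EleutherAI/pythia-70m",
    "use_peft": True, "lora_r": 16, "lora_alpha": 32, "lora_dropout": 0.05,
    "training_type": "dpo", "max_steps": 50, "batch_size": 2,
    "gradient_accumulation_steps": 4, "learning_rate": 5.0e-5,
    "max_length": 256, "warmup_steps": 5, "beta": 0.1, "max_prompt_length": 64,
    "dataset_name": "Anthropic/hh-rlhf", "dataset_split": "train",
    "dataset_size": 512, "seed": 42,
    "output_dir": "./canary_output", "metrics_warmup_steps": 5,
}

# Per-variant overrides: only the quantization flags change.
VARIANTS = {
    "dpo_fp16": {"load_in_4bit": False, "load_in_8bit": False},
    "dpo_8bit": {"load_in_4bit": False, "load_in_8bit": True},
    "dpo_4bit": {"load_in_4bit": True, "load_in_8bit": False},
}


def build_config(name: str) -> dict:
    """Copy the base settings, then apply the named variant's overrides."""
    cfg = copy.deepcopy(BASE)
    cfg["name"] = name
    cfg.update(VARIANTS[name])
    return cfg


def config_diff(a: dict, b: dict) -> set:
    """Keys whose values differ (or that exist in only one config)."""
    return {k for k in a.keys() | b.keys() if a.get(k) != b.get(k)}


configs = {name: build_config(name) for name in VARIANTS}
print(sorted(config_diff(configs["dpo_fp16"], configs["dpo_4bit"])))
```

Since PyYAML is installed in the setup cell, each dict could be written out with `yaml.safe_dump(cfg, f, sort_keys=False)` instead of maintaining three hand-written strings.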

In [None]:
# View all configs
!ls -la configs/quantization/

## 4. Running Full Precision Baseline

First, let's establish a full precision baseline.

In [None]:
# Run full precision canary
!python -m canary.cli run configs/quantization/dpo_fp16.yaml -o ./canary_output/fp16

In [None]:
# Load and display FP16 metrics
import json
from pathlib import Path

fp16_path = next(Path('./canary_output/fp16').rglob('metrics.json'))

with open(fp16_path) as f:
    fp16_metrics = json.load(f)

print("FP16 (Full Precision) Results")
print("=" * 40)
print(f"Step time (mean): {fp16_metrics['perf']['step_time']['mean']:.4f}s")
print(f"Tokens/sec: {fp16_metrics['perf']['approx_tokens_per_sec']:.0f}")
print(f"Peak memory: {fp16_metrics['perf']['max_mem_mb']:.0f} MB")
print(f"NaN steps: {fp16_metrics['stability']['nan_steps']}")
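The three loading cells repeat the same logic, and `next(...rglob(...))` returns whichever `metrics.json` the filesystem lists first if a directory holds several runs. A sketch of a reusable helper that picks the most recently modified file instead; `latest_metrics` and `summarize` are hypothetical names, the JSON fields are the ones this notebook already reads, and the demo uses placeholder values, not real measurements:

```python
import json
import tempfile
from pathlib import Path


def latest_metrics(run_dir: str) -> dict:
    """Load the most recently modified metrics.json under run_dir."""
    candidates = sorted(Path(run_dir).rglob("metrics.json"),
                        key=lambda p: p.stat().st_mtime)
    if not candidates:
        raise FileNotFoundError(f"no metrics.json found under {run_dir}")
    with open(candidates[-1]) as f:
        return json.load(f)


def summarize(metrics: dict) -> dict:
    """Extract the fields printed in this notebook."""
    perf = metrics["perf"]
    return {
        "step_time_mean_s": perf["step_time"]["mean"],
        "tokens_per_sec": perf["approx_tokens_per_sec"],
        "peak_mem_mb": perf["max_mem_mb"],
        "nan_steps": metrics["stability"]["nan_steps"],
    }


# Demo on a throwaway directory with placeholder values (not real results).
with tempfile.TemporaryDirectory() as tmp:
    run = Path(tmp) / "run_1"
    run.mkdir()
    fake = {"perf": {"step_time": {"mean": 0.25}, "approx_tokens_per_sec": 2048,
                     "max_mem_mb": 512}, "stability": {"nan_steps": 0}}
    (run / "metrics.json").write_text(json.dumps(fake))
    demo = summarize(latest_metrics(tmp))
    print(demo)
```

In this notebook the loading cells could then shrink to `fp16_metrics = latest_metrics('./canary_output/fp16')`.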

## 5. Running 8-bit Quantized Canary

In [None]:
# Run 8-bit quantized canary
!python -m canary.cli run configs/quantization/dpo_8bit.yaml -o ./canary_output/8bit

In [None]:
# Load and display 8-bit metrics
int8_path = next(Path('./canary_output/8bit').rglob('metrics.json'))

with open(int8_path) as f:
    int8_metrics = json.load(f)

print("8-bit Quantization Results")
print("=" * 40)
print(f"Step time (mean): {int8_metrics['perf']['step_time']['mean']:.4f}s")
print(f"Tokens/sec: {int8_metrics['perf']['approx_tokens_per_sec']:.0f}")
print(f"Peak memory: {int8_metrics['perf']['max_mem_mb']:.0f} MB")
print(f"NaN steps: {int8_metrics['stability']['nan_steps']}")

## 6. Running 4-bit Quantized Canary

In [None]:
# Run 4-bit quantized canary
!python -m canary.cli run configs/quantization/dpo_4bit.yaml -o ./canary_output/4bit

In [None]:
# Load and display 4-bit metrics
int4_path = next(Path('./canary_output/4bit').rglob('metrics.json'))

with open(int4_path) as f:
    int4_metrics = json.load(f)

print("4-bit Quantization Results")
print("=" * 40)
print(f"Step time (mean): {int4_metrics['perf']['step_time']['mean']:.4f}s")
print(f"Tokens/sec: {int4_metrics['perf']['approx_tokens_per_sec']:.0f}")
print(f"Peak memory: {int4_metrics['perf']['max_mem_mb']:.0f} MB")
print(f"NaN steps: {int4_metrics['stability']['nan_steps']}")

## 7. Memory Comparison Analysis

In [None]:
# Side-by-side comparison
print("Memory & Performance Comparison")
print("=" * 70)
print(f"{'Metric':<25} {'FP16':>12} {'8-bit':>12} {'4-bit':>12}")
print("-" * 70)

# Memory
fp16_mem = fp16_metrics['perf']['max_mem_mb']
int8_mem = int8_metrics['perf']['max_mem_mb']
int4_mem = int4_metrics['perf']['max_mem_mb']

print(f"{'Peak Memory (MB)':<25} {fp16_mem:>12.0f} {int8_mem:>12.0f} {int4_mem:>12.0f}")
print(f"{'Memory Savings':<25} {'baseline':>12} {(1 - int8_mem/fp16_mem)*100:>11.1f}% {(1 - int4_mem/fp16_mem)*100:>11.1f}%")

# Step time
fp16_time = fp16_metrics['perf']['step_time']['mean']
int8_time = int8_metrics['perf']['step_time']['mean']
int4_time = int4_metrics['perf']['step_time']['mean']

print(f"{'Step Time (s)':<25} {fp16_time:>12.4f} {int8_time:>12.4f} {int4_time:>12.4f}")
print(f"{'Step Time Change':<25} {'baseline':>12} {(int8_time/fp16_time - 1)*100:>+11.1f}% {(int4_time/fp16_time - 1)*100:>+11.1f}%")

# Throughput
fp16_tps = fp16_metrics['perf']['approx_tokens_per_sec']
int8_tps = int8_metrics['perf']['approx_tokens_per_sec']
int4_tps = int4_metrics['perf']['approx_tokens_per_sec']

print(f"{'Tokens/sec':<25} {fp16_tps:>12.0f} {int8_tps:>12.0f} {int4_tps:>12.0f}")
print(f"{'TPS Change':<25} {'baseline':>12} {(int8_tps/fp16_tps - 1)*100:>+11.1f}% {(int4_tps/fp16_tps - 1)*100:>+11.1f}%")
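The comparison mixes two sign conventions: lower is better for memory and step time, while higher is better for tokens/sec. A small sketch that makes the direction explicit so a reader does not have to remember which sign is good; the function and metric names are hypothetical labels, not fields defined by RLHF Canary, and the printed inputs are placeholders:

```python
# Direction of improvement for each metric compared in the table above.
HIGHER_IS_BETTER = {
    "peak_mem_mb": False,
    "step_time_mean_s": False,
    "tokens_per_sec": True,
}


def pct_change(baseline: float, value: float) -> float:
    """Signed percent change of value relative to baseline."""
    if baseline == 0:
        raise ValueError("baseline must be non-zero")
    return (value / baseline - 1) * 100


def verdict(metric: str, baseline: float, value: float) -> str:
    """Describe a change as better or worse, given the metric's direction."""
    change = pct_change(baseline, value)
    if change == 0:
        return f"{metric}: unchanged"
    improved = change > 0 if HIGHER_IS_BETTER[metric] else change < 0
    return f"{metric}: {change:+.1f}% ({'better' if improved else 'worse'})"


# Placeholder numbers for illustration, not measured results.
print(verdict("peak_mem_mb", 1000, 600))
print(verdict("tokens_per_sec", 2000, 1500))
```

With the variables defined above, `verdict("tokens_per_sec", fp16_tps, int4_tps)` describes the 4-bit throughput change relative to FP16.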

In [None]:
# Visual comparison (text-based chart)
def text_bar(value, max_val, width=40, label=""):
    filled = int((value / max_val) * width)
    bar = "|" + "=" * filled + " " * (width - filled) + "|"
    return f"{label:<10} {bar} {value:.0f}"

max_mem = max(fp16_mem, int8_mem, int4_mem)

print("\nMemory Usage Comparison")
print("=" * 60)
print(text_bar(fp16_mem, max_mem, label="FP16"))
print(text_bar(int8_mem, max_mem, label="8-bit"))
print(text_bar(int4_mem, max_mem, label="4-bit"))

### When to Use Each Mode

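These ranges are rough guides for models that come close to filling the GPU; a model as small as pythia-70m fits in FP16 on any of them.

| GPU Memory | Recommended | Notes |
|------------|-------------|-------|
| Under 8 GB | 4-bit | Essential for small GPUs |
| 8-12 GB | 8-bit | Good balance |
| 16+ GB | FP16 | Best quality and speed |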

**Trade-offs:**
- **4-bit**: Significant memory savings, slight quality/speed impact
- **8-bit**: Moderate savings, minimal quality impact
- **FP16**: Baseline quality and speed, requires more memory
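The table's selection logic can be sketched as "pick the highest precision whose estimated footprint fits". This is a rough heuristic under stated assumptions: weight memory is scaled by a hypothetical `overhead_factor` of 2 to cover activations, optimizer state, and runtime overhead. That multiplier is illustrative, not measured, and should be calibrated against the peak-memory numbers from your own runs.

```python
BYTES_PER_PARAM = {"fp16": 2.0, "8bit": 1.0, "4bit": 0.5}


def recommend_mode(num_params: int, vram_gb: float, overhead_factor: float = 2.0) -> str:
    """Pick the highest-precision mode whose estimated footprint fits in VRAM.

    overhead_factor is an assumption that scales weight memory to cover
    activations, optimizer state, and runtime overhead.
    """
    budget_bytes = vram_gb * 1e9
    for mode in ("fp16", "8bit", "4bit"):  # highest precision first
        if num_params * BYTES_PER_PARAM[mode] * overhead_factor <= budget_bytes:
            return mode
    raise ValueError("model is unlikely to fit even in 4-bit")


# A hypothetical 7B-parameter model on a 16 GB card.
print(recommend_mode(7_000_000_000, vram_gb=16))
```

For pythia-70m the estimate stays far below any of the listed memory sizes, which is why this notebook can run all three modes on a single T4.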

## 8. Baseline Management Workflows

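A regression check is only meaningful when the baseline came from the same GPU type and quantization mode as the run being checked; otherwise hardware and precision differences show up as false regressions.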

In [None]:
# Create baseline directory structure
!mkdir -p baselines/t4/fp16
!mkdir -p baselines/t4/8bit
!mkdir -p baselines/t4/4bit

print("Baseline directory structure:")
!find baselines -type d | sort

In [None]:
# Save baselines using canary save-baseline command
!python -m canary.cli save-baseline {fp16_path} baselines/t4/fp16/dpo_smoke.json
!python -m canary.cli save-baseline {int8_path} baselines/t4/8bit/dpo_smoke.json
!python -m canary.cli save-baseline {int4_path} baselines/t4/4bit/dpo_smoke.json

print("\nSaved baselines:")
!find baselines -name "*.json" | sort

### Recommended Baseline Organization

```
baselines/
├── t4/                    # GPU type
│   ├── fp16/
│   │   ├── dpo_smoke.json
│   │   └── ppo_smoke.json
│   ├── 8bit/
│   │   └── dpo_smoke.json
│   └── 4bit/
│       └── dpo_smoke.json
├── a100/                  # Different GPU
│   └── fp16/
│       └── dpo_smoke.json
└── README.md              # Document baselines
```

**Key principles:**
1. **Separate by GPU** - Performance varies significantly
2. **Separate by quantization** - Different memory/speed profiles
3. **Version control baselines** - Track changes over time
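The first two principles can be encoded so that the baseline path is derived rather than typed by hand. A sketch under assumptions: `gpu_slug`, `quant_slug`, and `baseline_path` are hypothetical helpers, and the list of known GPU names is illustrative; the device name would typically come from `torch.cuda.get_device_name(0)`.

```python
import re
from pathlib import Path


def gpu_slug(device_name: str) -> str:
    """Reduce a CUDA device name such as 'Tesla T4' to a directory name like 't4'."""
    known = ("a100", "h100", "l4", "t4", "v100", "a10g")  # illustrative list
    lowered = device_name.lower()
    for slug in known:
        if re.search(rf"\b{slug}\b", lowered):
            return slug
    # Fallback: keep alphanumerics so unknown GPUs still get a stable directory.
    return re.sub(r"[^a-z0-9]+", "_", lowered).strip("_")


def quant_slug(load_in_4bit: bool, load_in_8bit: bool) -> str:
    """Map the config's quantization flags to the directory names used above."""
    if load_in_4bit and load_in_8bit:
        raise ValueError("choose at most one quantization mode")
    return "4bit" if load_in_4bit else "8bit" if load_in_8bit else "fp16"


def baseline_path(root: str, device_name: str, load_in_4bit: bool,
                  load_in_8bit: bool, config_name: str) -> Path:
    """Build <root>/<gpu>/<quant>/<config_name>.json."""
    return (Path(root) / gpu_slug(device_name)
            / quant_slug(load_in_4bit, load_in_8bit) / f"{config_name}.json")


print(baseline_path("baselines", "Tesla T4", True, False, "dpo_smoke"))
```

Deriving the path from the device name and config flags makes it harder to save a 4-bit run into the FP16 directory by mistake.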

In [None]:
# Document baselines
import datetime

readme = """
# Baselines

## T4 GPU (16GB VRAM)

| Config | Quantization | Memory | Step Time | Created |
|--------|--------------|--------|-----------|----------|
| dpo_smoke | FP16 | {fp16_mem:.0f}MB | {fp16_time:.4f}s | {date} |
| dpo_smoke | 8-bit | {int8_mem:.0f}MB | {int8_time:.4f}s | {date} |
| dpo_smoke | 4-bit | {int4_mem:.0f}MB | {int4_time:.4f}s | {date} |

## Usage

Compare to appropriate baseline based on your hardware and quantization:

```bash
# FP16 comparison
canary compare current.json baselines/t4/fp16/dpo_smoke.json

# 4-bit comparison
canary compare current.json baselines/t4/4bit/dpo_smoke.json
```
""".format(
    fp16_mem=fp16_mem, fp16_time=fp16_time,
    int8_mem=int8_mem, int8_time=int8_time,
    int4_mem=int4_mem, int4_time=int4_time,
    date=datetime.date.today().isoformat(),  # Date the baselines were recorded
)

with open('baselines/README.md', 'w') as f:
    f.write(readme)

print("Created baselines/README.md")

## 9. Cross-Configuration Comparison

When a run and its baseline use different quantization levels, expect memory and step-time differences that come from the precision change rather than from a code regression.

In [None]:
# Compare 4-bit run to FP16 baseline (expect differences!)
print("Comparing 4-bit run to FP16 baseline")
print("(This will likely show 'regressions' due to quantization overhead)\n")

!python -m canary.cli compare {int4_path} baselines/t4/fp16/dpo_smoke.json --threshold-tier smoke

In [None]:
# Compare 4-bit run to 4-bit baseline (apples to apples)
print("Comparing 4-bit run to 4-bit baseline")
print("(This is the correct comparison for regression detection)\n")

!python -m canary.cli compare {int4_path} baselines/t4/4bit/dpo_smoke.json --threshold-tier smoke

### Cross-Configuration Comparison Guidelines

| Comparison | Valid? | Notes |
|------------|--------|-------|
| FP16 vs FP16 | Yes | Ideal for regression detection |
| 4-bit vs 4-bit | Yes | Apples to apples |
| 4-bit vs FP16 | No | Different performance profiles |
| T4 vs A100 | No | Hardware too different |

**Best practice:** Always compare like-to-like. Create separate baselines for each configuration.
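Because the recommended directory layout encodes GPU type and quantization, a CI script can refuse cross-configuration comparisons before calling `canary compare`. A minimal sketch, assuming baselines follow the `baselines/<gpu>/<quant>/<name>.json` layout recommended earlier; the helper names are hypothetical:

```python
from pathlib import PurePosixPath


def parse_baseline_path(path: str) -> dict:
    """Read GPU and quantization from a baselines/<gpu>/<quant>/<name>.json path."""
    parts = PurePosixPath(path).parts
    if len(parts) < 3:
        raise ValueError(f"unexpected baseline layout: {path}")
    return {"gpu": parts[-3], "quantization": parts[-2]}


def comparison_mismatches(run_gpu: str, run_quant: str, baseline: str) -> list:
    """List the ways a run and a baseline differ; an empty list means like-for-like."""
    meta = parse_baseline_path(baseline)
    problems = []
    if meta["gpu"] != run_gpu:
        problems.append(f"GPU: run={run_gpu}, baseline={meta['gpu']}")
    if meta["quantization"] != run_quant:
        problems.append(f"quantization: run={run_quant}, baseline={meta['quantization']}")
    return problems


print(comparison_mismatches("t4", "4bit", "baselines/t4/fp16/dpo_smoke.json"))
print(comparison_mismatches("t4", "4bit", "baselines/t4/4bit/dpo_smoke.json"))
```

A CI step could exit with an error whenever the list is non-empty, turning the "4-bit vs FP16: No" row above into an enforced rule.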

## 10. Practical Workflow: Memory-Constrained CI

Here's how to set up canary testing on limited hardware.

In [None]:
# Create a memory-optimized config for CI
ci_config = """
name: dpo_ci_4bit
description: Memory-optimized DPO canary for CI with limited GPU memory

model_name: EleutherAI/pythia-70m
use_peft: true
lora_r: 8           # Smaller LoRA rank to save memory
lora_alpha: 16
lora_dropout: 0.05

# 4-bit quantization for memory savings
load_in_4bit: true
load_in_8bit: false

training_type: dpo
max_steps: 50       # Same as the other configs; each step is cheaper instead
batch_size: 1       # Smaller batch size
gradient_accumulation_steps: 8  # Compensate with gradient accumulation
learning_rate: 5.0e-5
max_length: 128     # Shorter sequences
warmup_steps: 5

beta: 0.1
max_prompt_length: 32

dataset_name: Anthropic/hh-rlhf
dataset_split: train
dataset_size: 256
seed: 42

output_dir: ./canary_output
metrics_warmup_steps: 5
"""

with open('configs/dpo_ci_4bit.yaml', 'w') as f:
    f.write(ci_config)

print("Created memory-optimized CI config: configs/dpo_ci_4bit.yaml")
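Halving `batch_size` while doubling `gradient_accumulation_steps` keeps the number of examples per optimizer step constant, but the smaller `dataset_size` changes how often each example is revisited. A quick check of that arithmetic, with values copied from `dpo_4bit.yaml` and `dpo_ci_4bit.yaml` (the helper names are illustrative):

```python
def effective_batch(batch_size: int, grad_accum: int) -> int:
    """Examples contributing to each optimizer step."""
    return batch_size * grad_accum


def samples_seen(max_steps: int, batch_size: int, grad_accum: int) -> int:
    """Total examples processed over the run."""
    return max_steps * effective_batch(batch_size, grad_accum)


# Values copied from dpo_4bit.yaml and dpo_ci_4bit.yaml.
standard = {"batch_size": 2, "grad_accum": 4, "max_steps": 50, "dataset_size": 512}
ci = {"batch_size": 1, "grad_accum": 8, "max_steps": 50, "dataset_size": 256}

for label, cfg in (("standard", standard), ("ci", ci)):
    eff = effective_batch(cfg["batch_size"], cfg["grad_accum"])
    seen = samples_seen(cfg["max_steps"], cfg["batch_size"], cfg["grad_accum"])
    print(f"{label}: effective batch {eff}, {seen} samples "
          f"(~{seen / cfg['dataset_size']:.2f} passes over {cfg['dataset_size']} examples)")
```

Both configs process 400 preference pairs over 50 steps. Assuming the trainer keeps cycling through the data until `max_steps` (as the Hugging Face `Trainer` does), the CI run makes more than one pass over its 256 examples while the standard run stays under one pass over 512, so its loss curve is not directly comparable to the standard 4-bit baseline.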

In [None]:
# Run the memory-optimized config
!python -m canary.cli run configs/dpo_ci_4bit.yaml -o ./canary_output/ci_4bit

In [None]:
# Check memory usage of optimized config
ci_path = next(Path('./canary_output/ci_4bit').rglob('metrics.json'))

with open(ci_path) as f:
    ci_metrics = json.load(f)

print("Memory-Optimized CI Results")
print("=" * 40)
print(f"Peak memory: {ci_metrics['perf']['max_mem_mb']:.0f} MB")
print(f"Step time: {ci_metrics['perf']['step_time']['mean']:.4f}s")
print(f"Total duration: {ci_metrics['duration_seconds']:.1f}s")

print(f"\nMemory savings vs FP16: {(1 - ci_metrics['perf']['max_mem_mb']/fp16_mem)*100:.1f}%")

### CI Workflow Example

```yaml
# .github/workflows/canary.yml
name: Canary Tests

on: [pull_request]

jobs:
  canary:
    runs-on: [self-hosted, gpu-limited]  # 8GB GPU runner
    steps:
      - uses: actions/checkout@v4

      - name: Install RLHF Canary
        run: pip install -e .

      - name: Run 4-bit canary
        run: canary run configs/dpo_ci_4bit.yaml -o ./output

      - name: Compare to 4-bit baseline
        run: |
          METRICS=$(find ./output -name metrics.json | head -n 1)
          canary compare "$METRICS" baselines/limited_gpu/4bit/main.json
```

## 11. Summary

### Key Takeaways

1. **4-bit quantization** cuts weight memory by ~75% versus FP16, with a slight speed impact
2. **8-bit quantization** cuts weight memory by ~50% versus FP16, with minimal speed impact
3. **Always compare like-to-like** - don't compare 4-bit runs to FP16 baselines
4. **Organize baselines by GPU and quantization** for proper comparisons
5. **Use `save-baseline`** command for proper baseline management
6. **Optimize configs for CI** - smaller LoRA, shorter sequences, smaller batches

### Quick Reference

```yaml
# Enable 4-bit quantization
load_in_4bit: true
load_in_8bit: false

# Enable 8-bit quantization
load_in_4bit: false
load_in_8bit: true

# Full precision (default)
load_in_4bit: false
load_in_8bit: false
```

```bash
# Save baseline
canary save-baseline ./output/metrics.json baselines/t4/4bit/main.json

# Compare (use matching baseline!)
canary compare ./current/metrics.json baselines/t4/4bit/main.json
```

### Next Steps

- [01_quickstart.ipynb](01_quickstart.ipynb) - Core workflow basics
- [07_ci_cd_integration.ipynb](07_ci_cd_integration.ipynb) - GitHub integration
- [08_configuration_and_thresholds.ipynb](08_configuration_and_thresholds.ipynb) - Advanced configuration