# Complete Guide to LoRA & QLoRA Fine-tuning (2025)

## Practical Implementation: MITRE ATT&CK TTP Mapping

This comprehensive notebook covers:
1. **PEFT Methods** - LoRA, QLoRA, DoRA, AdaLoRA
2. **Quantization** - 4-bit, 8-bit, NF4, bitsandbytes
3. **Mixed Precision Training** - FP16, BF16, FP8
4. **Practical Fine-tuning** - Qwen2.5-1.5B on MITRE TTP dataset
5. **SOTA Techniques** - Flash Attention, Paged Optimizers, Gradient Checkpointing

**Model:** Qwen2.5-1.5B (state-of-the-art 1B+ model)

**Task:** Map cybersecurity threat intelligence to MITRE ATT&CK TTPs

---
## Architecture Overview: LoRA vs QLoRA

### LoRA (Low-Rank Adaptation) Architecture

```mermaid
graph TD
    A[Input X] --> B[Frozen Pre-trained Weight W]
    A --> C[LoRA Low-Rank Matrices]
    
    C --> D[Matrix A: r×d]
    C --> E[Matrix B: d×r]
    
    D --> F["BA product (rank r)"]
    E --> F
    
    B --> G["W·x (frozen)"]
    F --> H["BA·x (trainable)"]
    
    G --> I["Output: W·x + α/r·BA·x"]
    H --> I
    
    style B fill:#ff9999
    style D fill:#99ff99
    style E fill:#99ff99
    style I fill:#9999ff
```

**Key Concepts:**
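- **W**: Frozen pre-trained weights (d × d)
- **A**: Trainable low-rank matrix (r × d), projects the input down to rank r
- **B**: Trainable low-rank matrix (d × r), projects back up; initialized to zero so training starts exactly from W
- **r**: Rank (typically 8, 16, 32, 64)
- **α**: Scaling factor; the low-rank update B·A·x is multiplied by α/r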

**Formula:** `h = W₀x + (α/r)·B·A·x`

**Parameter Reduction (per d × d matrix):**
- Original: d × d parameters
- LoRA: 2 × d × r parameters (A and B)
- Example: d=4096, r=16 → 131,072 vs 16,777,216 parameters, a ~99.2% reduction
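The reduction arithmetic can be checked directly. A minimal sketch for a general d × k weight (the example above is the square case d = k = 4096 with r = 16), which works out to roughly a 99.2% reduction per matrix:

```python
def lora_param_counts(d: int, k: int, r: int) -> tuple[int, int]:
    """Return (full, lora) parameter counts for one d x k weight matrix.

    Full fine-tuning updates all d*k entries of W; LoRA instead trains
    B (d x r) and A (r x k), i.e. r*(d + k) entries.
    """
    full = d * k
    lora = r * (d + k)
    return full, lora


full, lora = lora_param_counts(d=4096, k=4096, r=16)
reduction = 100 * (1 - lora / full)
print(f"full: {full:,}  lora: {lora:,}  reduction: {reduction:.2f}%")
# full: 16,777,216  lora: 131,072  reduction: 99.22%
```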

### QLoRA Architecture

```mermaid
graph TB
    A["Original Model (FP32/FP16)"] --> B["4-bit Quantization (NF4)"]
    
    B --> C["Quantized Weights (4-bit)"]
    C --> D["Double Quantization"]
    D --> E["Quantized Constants"]
    
    C --> F["Dequantize to BF16"]
    
    F --> G["Forward Pass"]
    G --> H["LoRA Adapters (BF16)"]
    
    H --> I["Gradient Computation"]
    I --> J["Paged Optimizer"]
    
    J --> K["Update LoRA Only"]
    K --> H
    
    style C fill:#ff9999
    style H fill:#99ff99
    style J fill:#9999ff
```

**QLoRA Innovations:**
1. **4-bit NormalFloat (NF4)** - Information-theoretically optimal for normal distributions
2. **Double Quantization** - Quantize the quantization constants
3. **Paged Optimizers** - Handle memory spikes using CPU-GPU paging

**Memory Savings (model weights only):**
- 7B model: ~28GB (FP32) → ~3.5GB (4-bit)
- 13B model: ~52GB (FP32) → ~6.5GB (4-bit)
- **Activations, LoRA adapters, and optimizer states come on top, but a 13B model can still be fine-tuned on a 24GB GPU**
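The weight-memory figures can be reproduced from parameter count and bit width (1 GB = 10⁹ bytes). A small sketch; it deliberately ignores quantization constants, adapters, activations, and optimizer state:

```python
def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Approximate weight storage in GB (1 GB = 1e9 bytes).

    Counts only the weights themselves: quantization constants, LoRA
    adapters, activations, and optimizer states are not included.
    """
    return n_params * bits_per_param / 8 / 1e9


for name, n in [("7B", 7e9), ("13B", 13e9)]:
    row = {bits: weight_memory_gb(n, bits) for bits in (32, 16, 8, 4)}
    print(name, {f"{b}-bit": f"{gb:.1f} GB" for b, gb in row.items()})
```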

---
## PEFT Methods Comparison (2025)

```mermaid
graph LR
    A[PEFT Methods] --> B[Adapter-based]
    A --> C[Low-Rank]
    A --> D[Prompt-based]
    A --> E[Sparse]
    
    B --> B1[Adapters]
    B --> B2[Parallel Adapters]
    
    C --> C1[LoRA]
    C --> C2[QLoRA]
    C --> C3[AdaLoRA]
    C --> C4[DoRA]
    
    D --> D1[Prefix Tuning]
    D --> D2[P-Tuning]
    D --> D3[Prompt Tuning]
    
    E --> E1[IA³]
    E --> E2[BitFit]
    
    style C1 fill:#90EE90
    style C2 fill:#FFD700
    style C3 fill:#87CEEB
    style C4 fill:#FFA500
```

| Method | Trainable Params | Memory | Speed | Best For |
|--------|-----------------|--------|-------|----------|
| **LoRA** | 0.1-1% | Medium | Fast | General fine-tuning |
| **QLoRA** | 0.1-1% | **Very Low** | Fast | Memory-constrained |
| **AdaLoRA** | 0.05-0.5% | Medium | Medium | Adaptive rank allocation |
| **DoRA** | 0.1-1% (LoRA + magnitude vector) | Medium | Medium | More robust hyperparameter tuning |
| **IA³** | 0.01% | **Lowest** | **Fastest** | Extremely low resource |
| **Prefix Tuning** | 0.1% | Low | Medium | Few-shot learning |
| **Full Fine-tuning** | 100% | **Highest** | Slow | Maximum performance |

**Key Insights:**
- **DoRA (Weight-Decomposed Low-Rank Adaptation)**: Separates magnitude and direction of weight updates, more robust to hyperparameter changes than LoRA
- **QLoRA Paper Finding**: Very little difference between rank 8 and 256 when LoRA applied to all layers
- **Target Modules**: Apply LoRA to both Attention AND MLP layers for best performance
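DoRA's magnitude/direction split can be illustrated in a few lines of NumPy. This is a minimal sketch of the idea, not PEFT's implementation: it treats a PyTorch-style `(out, in)` weight and normalizes each output row, and the helper name `dora_weight` is hypothetical. In PEFT itself, DoRA is switched on with `LoraConfig(use_dora=True)`.

```python
import numpy as np


def dora_weight(W0, A, B, m, alpha, r):
    """DoRA-style weight for a linear layer with weight shape (out, in).

    The scaled LoRA update (alpha/r)*B@A changes the *direction*; each output
    row is renormalized to unit length and rescaled by a learned magnitude m.
    """
    V = W0 + (alpha / r) * (B @ A)                 # direction before normalization
    row_norm = np.linalg.norm(V, axis=1, keepdims=True)
    return m[:, None] * V / row_norm


rng = np.random.default_rng(0)
out_dim, in_dim, r, alpha = 6, 8, 2, 2
W0 = rng.normal(size=(out_dim, in_dim))
A = rng.normal(size=(r, in_dim)) * 0.01            # LoRA-style init: A random, B zero
B = np.zeros((out_dim, r))
m = np.linalg.norm(W0, axis=1)                     # magnitude initialized from W0's row norms

print(np.allclose(dora_weight(W0, A, B, m, alpha, r), W0))  # True: at init DoRA reproduces W0
```

Because the direction is always unit-normalized, `m` alone controls each row's scale, which is the decoupling the DoRA description above refers to.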

---
## Setup and Installation

In [None]:
# Install required packages
# (version specifiers are quoted so the shell does not treat ">" as an output redirect)
!pip install -q "transformers>=4.44.0"
!pip install -q "peft>=0.12.0"
!pip install -q "accelerate>=0.33.0"
!pip install -q "bitsandbytes>=0.43.0"
!pip install -q "datasets>=2.20.0"
!pip install -q "trl>=0.9.0"
!pip install -q flash-attn --no-build-isolation
!pip install -q scipy
!pip install -q wandb

In [None]:
import torch
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import (
    LoraConfig,
    AdaLoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType
)
from datasets import Dataset, load_dataset
import json
import numpy as np
from typing import Dict, List
import pandas as pd

# Check GPU availability
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

---
## Quantization Deep Dive

### Quantization Comparison

```mermaid
graph TD
    A[Precision Types] --> B[FP32: 32-bit float]
    A --> C[FP16: 16-bit float]
    A --> D[BF16: 16-bit bfloat]
    A --> E[INT8: 8-bit integer]
    A --> F[NF4: 4-bit normalfloat]
    
    B --> B1["Range: ±3.4×10³⁸"]
    B --> B2["Memory: 4 bytes"]
    
    C --> C1["Range: ±65,504"]
    C --> C2["Memory: 2 bytes"]
    
    D --> D1["Range: ±3.4×10³⁸"]
    D --> D2["Memory: 2 bytes"]
    D --> D3["FP32 range, less precision"]
    
    E --> E1["Range: -128 to 127"]
    E --> E2["Memory: 1 byte"]
    
    F --> F1["Optimal for normal dist"]
    F --> F2["Memory: 0.5 bytes"]
    F --> F3["QLoRA innovation"]
    
    style B fill:#ff9999
    style D fill:#99ff99
    style F fill:#FFD700
```
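The range entries in the diagram follow from each format's exponent/mantissa split. A pure-Python sketch that derives them, showing that BF16 keeps FP32's 8 exponent bits (same range and smallest normal value) while giving up mantissa precision:

```python
def float_format_stats(exp_bits: int, mantissa_bits: int) -> dict:
    """Largest finite value, smallest normal value and machine epsilon
    for an IEEE-754-style binary format (bias = 2**(exp_bits-1) - 1)."""
    bias = 2 ** (exp_bits - 1) - 1
    max_exp = bias                      # the all-ones exponent is reserved for inf/NaN
    return {
        "max": (2 - 2 ** -mantissa_bits) * 2.0 ** max_exp,
        "min_normal": 2.0 ** (1 - bias),
        "epsilon": 2.0 ** -mantissa_bits,
    }


formats = {"FP32": (8, 23), "FP16": (5, 10), "BF16": (8, 7)}
for name, (e, m) in formats.items():
    s = float_format_stats(e, m)
    print(f"{name}: max={s['max']:.4g}  min_normal={s['min_normal']:.3g}  eps={s['epsilon']:.3g}")
```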

In [None]:
# Quantization Configuration Options

# 1. 4-bit QLoRA Configuration (Most Memory Efficient)
qlora_4bit_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NormalFloat 4-bit
    bnb_4bit_compute_dtype=torch.bfloat16,  # Computation dtype
    bnb_4bit_use_double_quant=True,         # Double quantization
)

# 2. 8-bit Configuration (Better Accuracy)
int8_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
)

# 3. No Quantization (Standard LoRA)
no_quant_config = None  # Load in FP16/BF16

print("Quantization configs defined:")
print("1. 4-bit NF4 (QLoRA)")
print("2. 8-bit INT8")
print("3. No quantization (FP16/BF16)")
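To make the NF4 idea concrete, here is a simplified, pure-Python sketch of block-wise 4-bit quantization with a normal-quantile codebook. It follows the construction described for QLoRA's NF4 (asymmetric quantiles with an exact zero, absmax scaling per block of 64 weights), but the `offset` value and the helper names are illustrative assumptions; bitsandbytes' real kernels pack indices two per byte and run on the GPU.

```python
import random
from statistics import NormalDist


def nf4_like_codebook(offset: float = 0.9677) -> list[float]:
    """16 levels from standard-normal quantiles, normalized to [-1, 1].

    8 positive levels, 7 negative levels and an exact zero; `offset` is the
    outermost quantile (an assumed value for this sketch).
    """
    nd = NormalDist()

    def linspace(a, b, n):
        return [a + (b - a) * i / (n - 1) for i in range(n)]

    pos = [nd.inv_cdf(p) for p in linspace(offset, 0.5, 9)[:-1]]   # 8 positive
    neg = [-nd.inv_cdf(p) for p in linspace(offset, 0.5, 8)[:-1]]  # 7 negative
    levels = sorted(neg + [0.0] + pos)
    scale = max(abs(v) for v in levels)
    return [v / scale for v in levels]


def quantize_block(block, codebook):
    """Absmax-scale a block, then store each value as a 4-bit codebook index."""
    absmax = max(abs(x) for x in block) or 1.0
    idx = [min(range(len(codebook)), key=lambda i: abs(codebook[i] - x / absmax)) for x in block]
    return idx, absmax


def dequantize_block(idx, absmax, codebook):
    return [codebook[i] * absmax for i in idx]


random.seed(0)
codebook = nf4_like_codebook()
weights = [random.gauss(0.0, 0.02) for _ in range(64)]   # one block of 64 weights
idx, absmax = quantize_block(weights, codebook)
restored = dequantize_block(idx, absmax, codebook)
err = max(abs(a - b) for a, b in zip(weights, restored))
print(f"levels={len(codebook)}  absmax={absmax:.4f}  max abs error={err:.5f}")
```

Per block, only the 4-bit indices plus one scaling constant (`absmax`) are stored; that constant is exactly what double quantization compresses further.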

### Quantization Impact

| Config | Memory (7B) | Speed | Accuracy | Use Case |
|--------|------------|-------|----------|----------|
| **FP32** | ~28 GB | 1x | 100% | Baseline |
| **FP16** | ~14 GB | 2-3x | 99.9% | Standard training |
| **BF16** | ~14 GB | 2-3x | 99.9% | **Recommended** |
| **INT8** | ~7 GB | Slower than BF16 (dequantization overhead) | 99.5% | Good balance |
| **NF4** | ~3.5 GB | Slower than BF16 (dequantization overhead) | 99%+ | **Memory critical** |

**Key Findings (2024-2025):**
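- 8-bit runs have been reported to **converge at least as fast** as BF16 in some experiments; results are task-dependent
- 4-bit NF4 matches 16-bit fine-tuning performance on the benchmarks reported in the QLoRA paper
- Double quantization saves ~0.37 bits per parameter (~0.3 GB per 7B params)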
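The double-quantization saving follows from the block sizes described in the QLoRA paper: one FP32 constant per 64 weights at the first level, and with double quantization those constants become 8-bit values in blocks of 256, each block keeping one FP32 constant. A sketch of that arithmetic (the helper name is illustrative):

```python
def bits_per_param_for_constants(block_size: int = 64, double_quant: bool = False,
                                 dq_block_size: int = 256) -> float:
    """Overhead, in bits per weight, of storing the absmax scaling constants."""
    if not double_quant:
        return 32 / block_size                      # one FP32 constant per block
    # 8-bit constants per block + one FP32 constant per group of dq_block_size constants
    return 8 / block_size + 32 / (block_size * dq_block_size)


plain = bits_per_param_for_constants()
dq = bits_per_param_for_constants(double_quant=True)
saved_bits = plain - dq
saved_gb_7b = saved_bits * 7e9 / 8 / 1e9
print(f"{plain:.3f} -> {dq:.3f} bits/param, saving {saved_bits:.3f} bits (~{saved_gb_7b:.2f} GB for 7B)")
```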

---
## MITRE ATT&CK Dataset Preparation

**Note on Production Datasets:**

This notebook uses 8 synthetic examples for demonstration. For production deployment, use larger, high-quality datasets:

### Recommended Public MITRE TTP Datasets:

1. **AttackQA Dataset** (Recommended - 25,335 samples)
   - Source: Research paper (arxiv.org/html/2411.01073v1)
   - Format: Question-answer pairs from MITRE ATT&CK knowledge base
   - Coverage: Comprehensive across all tactics and techniques
   - Best for: General MITRE knowledge and instruction tuning

2. **TRAM Dataset** (4,070 labeled sentences from 150 CTI reports)
   - Source: github.com/center-for-threat-informed-defense/tram
   - Coverage: 50 out of 625 techniques (limitation!)
   - Best for: Real-world CTI text understanding

3. **CTI-HAL Dataset** (2025 - Most recent)
   - Source: Research paper (arxiv.org/html/2504.05866v1)
   - Quality: Manually annotated with high inter-annotator agreement
   - Best for: High-quality CTI report mapping

4. **Adversary Emulation Library (AEL)**
   - Source: MITRE
   - Format: Concise campaign reports with technique IDs
   - Best for: Real attack scenarios

**Production Strategy:** Combine AttackQA (25K) + TRAM (4K) + Custom domain examples (500-1K) for robust 30K sample dataset.

In [None]:
# MITRE ATT&CK TTP Mapping Dataset
# Based on TRAM (Threat Report ATT&CK Mapper) format

# Sample training data structure
mitre_samples = [
    {
        "text": "The adversary used PowerShell to execute a malicious script that downloaded additional payloads from a command and control server.",
        "techniques": ["T1059.001", "T1105"],
        "technique_names": ["PowerShell", "Ingress Tool Transfer"],
        "tactics": ["Execution", "Command and Control"]
    },
    {
        "text": "Attackers gained initial access through a spear-phishing email containing a malicious attachment that exploited a vulnerability in Microsoft Office.",
        "techniques": ["T1566.001", "T1203"],
        "technique_names": ["Spearphishing Attachment", "Exploitation for Client Execution"],
        "tactics": ["Initial Access", "Execution"]
    },
    {
        "text": "The threat actor created a scheduled task to maintain persistence on the compromised system, executing a backdoor every hour.",
        "techniques": ["T1053.005"],
        "technique_names": ["Scheduled Task"],
        "tactics": ["Persistence", "Privilege Escalation"]
    },
    {
        "text": "Credential dumping was performed using Mimikatz to extract plaintext passwords and NTLM hashes from LSASS memory.",
        "techniques": ["T1003.001"],
        "technique_names": ["LSASS Memory"],
        "tactics": ["Credential Access"]
    },
    {
        "text": "The malware enumerated all running processes and network connections to identify security tools and establish situational awareness.",
        "techniques": ["T1057", "T1049"],
        "technique_names": ["Process Discovery", "System Network Connections Discovery"],
        "tactics": ["Discovery"]
    },
    {
        "text": "Sensitive files were collected and compressed into an archive before exfiltration over the encrypted C2 channel.",
        "techniques": ["T1560", "T1041"],
        "technique_names": ["Archive Collected Data", "Exfiltration Over C2 Channel"],
        "tactics": ["Collection", "Exfiltration"]
    },
    {
        "text": "Ransomware was deployed across the network using PsExec to encrypt files with AES-256 encryption.",
        "techniques": ["T1486", "T1021.002"],
        "technique_names": ["Data Encrypted for Impact", "SMB/Windows Admin Shares"],
        "tactics": ["Impact", "Lateral Movement"]
    },
    {
        "text": "The attacker disabled Windows Defender and tampered with event logs to evade detection and hide their activities.",
        "techniques": ["T1562.001", "T1070.001"],
        "technique_names": ["Disable or Modify Tools", "Clear Windows Event Logs"],
        "tactics": ["Defense Evasion"]
    },
]

print(f"Sample dataset size: {len(mitre_samples)} examples")
print("\nExample:")
print(json.dumps(mitre_samples[0], indent=2))

In [None]:
# Create instruction-following format for LLM fine-tuning

def format_mitre_example(example: Dict) -> str:
    """
    Format MITRE example for instruction fine-tuning
    """
    instruction = "Analyze the following cybersecurity incident description and identify the MITRE ATT&CK techniques (TTPs) used by the adversary."
    
    # Format techniques
    techniques_str = ", ".join([
        f"{tech} ({name})" 
        for tech, name in zip(example['techniques'], example['technique_names'])
    ])
    
    tactics_str = ", ".join(example['tactics'])
    
    formatted = f"""<|im_start|>system
You are a cybersecurity expert trained in MITRE ATT&CK framework. Analyze threat intelligence and map it to specific TTPs.<|im_end|>
<|im_start|>user
{instruction}

Incident Description:
{example['text']}<|im_end|>
<|im_start|>assistant
Based on the incident description, the following MITRE ATT&CK techniques were identified:

**Techniques:** {techniques_str}

**Tactics:** {tactics_str}

**Analysis:** The adversary employed these techniques to achieve their objectives, demonstrating a multi-stage attack pattern across the cyber kill chain.<|im_end|>"""
    
    return formatted

# Format dataset
formatted_dataset = [
    {"text": format_mitre_example(example)}
    for example in mitre_samples
]

print("Formatted example:")
print(formatted_dataset[0]["text"])

In [None]:
# Create HuggingFace Dataset
from datasets import Dataset

train_dataset = Dataset.from_list(formatted_dataset)

print(f"Dataset created with {len(train_dataset)} examples")
print(f"\nDataset features: {train_dataset.features}")

---
## Model Loading with Different Configurations

In [None]:
# Model selection: Qwen2.5-1.5B (state-of-the-art 1B model)
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    padding_side="right",
)

# Set pad token if not present
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Tokenizer loaded: {MODEL_NAME}")
print(f"Vocabulary size: {len(tokenizer)}")
print(f"Pad token: {tokenizer.pad_token}")
print(f"EOS token: {tokenizer.eos_token}")

In [None]:
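# Load model with 4-bit quantization (QLoRA)

# Note on Flash Attention 2:
# - Requires an Ampere/Ada/Hopper GPU (RTX 30xx/40xx, A100, H100) and the flash-attn package
# - Not supported on older GPUs (V100, T4, GTX series)
# - transformers raises an error rather than silently falling back, so we select
#   PyTorch SDPA attention ("sdpa") when Flash Attention 2 is unavailable
import importlib.util

use_flash_attn = (
    torch.cuda.is_available()
    and torch.cuda.get_device_capability(0)[0] >= 8      # Ampere or newer
    and importlib.util.find_spec("flash_attn") is not None
)
attn_impl = "flash_attention_2" if use_flash_attn else "sdpa"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=qlora_4bit_config,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_impl,  # Flash Attention 2 when available, else SDPA
)
print(f"Attention implementation: {attn_impl}")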

# Prepare for k-bit training
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

print(f"\nModel loaded with 4-bit quantization")
print(f"Device map: {model.hf_device_map}")

# Calculate trainable parameters
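def print_trainable_parameters(model):
    """Print trainable vs. total parameter counts.

    bitsandbytes packs two 4-bit weights per byte (Params4bit with the default
    uint8 storage), so numel() reports half the real count; double it for those.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        if param.__class__.__name__ == "Params4bit":
            num_params *= 2
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    print(
        f"trainable params: {trainable_params:,} || "
        f"all params: {all_param:,} || "
        f"trainable%: {100 * trainable_params / all_param:.4f}"
    )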

print("\nBefore adding LoRA:")
print_trainable_parameters(model)

---
## LoRA Configuration Options

In [None]:
# Standard LoRA Configuration
lora_config = LoraConfig(
    r=16,                                    # Rank
    lora_alpha=16,                           # Scaling factor: updates are scaled by alpha/r (alpha=r -> 1.0; alpha=2r doubles it)
    target_modules=[                         # Which modules to apply LoRA
        "q_proj",        # Query projection
        "k_proj",        # Key projection
        "v_proj",        # Value projection
        "o_proj",        # Output projection
        "gate_proj",     # MLP gate projection
        "up_proj",       # MLP up projection
        "down_proj",     # MLP down projection
    ],
    lora_dropout=0.05,                       # Dropout for LoRA layers
    bias="none",                             # Bias strategy
    task_type=TaskType.CAUSAL_LM,            # Task type
    inference_mode=False,                     # Training mode
)

print("LoRA Configuration:")
print(f"  Rank (r): {lora_config.r}")
print(f"  Alpha: {lora_config.lora_alpha}")
print(f"  Alpha/r ratio: {lora_config.lora_alpha / lora_config.r}")
print(f"  Target modules: {lora_config.target_modules}")
print(f"  Dropout: {lora_config.lora_dropout}")

In [None]:
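# Alternative: AdaLoRA Configuration (Adaptive Rank)
adalora_config = AdaLoraConfig(
    r=16,
    lora_alpha=16,                           # alpha = r (scaling factor 1.0)
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
    ],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    # AdaLoRA specific parameters
    target_r=8,                              # Target average rank after pruning
    init_r=12,                               # Initial rank per module
    tinit=0,                                 # Warmup steps before rank pruning starts
    tfinal=1000,                             # Final steps trained with the fixed rank budget
    deltaT=10,                               # Steps between budget reallocations
    total_step=2000,                         # Total training steps (placeholder; required by recent PEFT versions)
)
# Note: the rank budget is only reallocated if model.base_model.update_and_allocate(global_step)
# is called after each optimizer step (e.g., from a custom training loop or callback).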

print("\nAdaLoRA Configuration (adaptive rank allocation):")
print(f"  Initial rank: {adalora_config.init_r}")
print(f"  Target rank: {adalora_config.target_r}")

In [None]:
# Apply LoRA to model
model = get_peft_model(model, lora_config)

print("\nAfter adding LoRA:")
print_trainable_parameters(model)

# Print model structure
print("\nLoRA adapter structure:")
for name, module in model.named_modules():
    if "lora" in name.lower():
        print(f"  {name}: {type(module).__name__}")

---
## Training Configuration

### Mixed Precision Training

```mermaid
graph LR
    A[Mixed Precision Training] --> B[FP16]
    A --> C[BF16]
    A --> D[FP8]
    
    B --> B1["Standard mixed precision"]
    B --> B2["Loss scaling required"]
    B --> B3["Range: ±65,504"]
    
    C --> C1["Better for ML/LLMs"]
    C --> C2["No loss scaling"]
    C --> C3["Range: ±3.4×10³⁸"]
    C --> C4["Recommended"]
    
    D --> D1["Newest (2024+)"]
    D --> D2["H100 / Ada (RTX 40xx) GPUs"]
    D --> D3["Up to ~2x BF16 matmul throughput"]
    
    style C fill:#90EE90
    style C4 fill:#FFD700
```
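The "Loss scaling required" entry for FP16 comes from its narrow exponent range: tiny gradients underflow to zero. A stdlib-only sketch shows the effect, using `struct`'s IEEE half-precision `e` format and a truncation-based BF16 approximation (an assumption: hardware typically rounds to nearest even rather than truncating). The loss scale of 65536 is illustrative.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE half precision ('e' format)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Approximate BF16 by keeping the top 16 bits of the FP32 encoding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


grad = 1e-8                          # a tiny gradient value
print("FP16:", to_fp16(grad))        # underflows to 0.0
print("BF16:", to_bf16(grad))        # still representable (FP32-sized exponent)

loss_scale = 65536.0                 # scaling the loss scales every gradient by the same factor
scaled = to_fp16(grad * loss_scale)  # now inside FP16's normal range
print("FP16 with loss scaling:", scaled / loss_scale)  # unscale before the optimizer step
```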

In [None]:
# Training arguments with SOTA optimizations

training_args = TrainingArguments(
    # Output
    output_dir="./qwen-mitre-qlora",
    run_name="qwen-1.5b-mitre-qlora",
    
    # Training hyperparameters
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,          # Effective batch size = 16
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    
    # Mixed precision
    bf16=True,                               # Use BF16 (recommended)
    tf32=True,                               # Use TF32 matmuls (requires Ampere or newer; set False on older GPUs)
    
    # Optimization
    optim="paged_adamw_8bit",               # Paged optimizer for QLoRA
    gradient_checkpointing=True,             # Save memory
    max_grad_norm=1.0,                       # Gradient clipping
    
    # Logging
    logging_steps=10,
    logging_dir="./logs",
    report_to="none",                        # Change to "wandb" for W&B logging
    
    # Saving
    save_strategy="epoch",
    save_total_limit=2,
    
    # Evaluation
    eval_strategy="no",
    
    # Performance
    dataloader_num_workers=4,
    group_by_length=True,                    # Group similar lengths
    
    # Reproducibility
    seed=42,
)

print("Training Configuration:")
print(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Optimizer: {training_args.optim}")
print(f"  Precision: BF16={training_args.bf16}, TF32={training_args.tf32}")
print(f"  Gradient checkpointing: {training_args.gradient_checkpointing}")

### Optimizer Comparison

| Optimizer | Memory | Speed | Convergence | Use Case |
|-----------|--------|-------|-------------|----------|
| **AdamW** | High | Fast | Good | Standard |
| **AdamW 8-bit** | Medium | Fast | Good | Memory saving |
| **Paged AdamW 8-bit** | **Low** | Fast | Good | **QLoRA** |
| **Adafactor** | Very Low | Medium | Variable | Extreme memory constraints |
| **Lion** | Low | Fast | Good (needs a lower LR than AdamW) | Newer (2023) |
| **SGD** | Lowest | Slow | Poor | Not recommended for LLMs |

In [None]:
# Tokenize dataset

def tokenize_function(examples):
    """Tokenize text examples.

    No padding here: the data collator pads each batch dynamically,
    which also lets group_by_length=True actually group similar lengths.
    """
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=512,
    )

tokenized_dataset = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)

print(f"Tokenized dataset: {tokenized_dataset}")
print(f"\nExample tokenized sequence length: {len(tokenized_dataset[0]['input_ids'])}")

In [None]:
# Data collator for language modeling
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # Causal LM, not masked LM
)

print("Data collator created for causal language modeling")

---
## Training

In [None]:
# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

print("Trainer initialized")
print(f"\nTraining will start with:")
print(f"  - {len(tokenized_dataset)} training examples")
print(f"  - {training_args.num_train_epochs} epochs")
print(f"  - Batch size: {training_args.per_device_train_batch_size}")
print(f"  - Gradient accumulation: {training_args.gradient_accumulation_steps}")

In [None]:
# Start training
print("Starting training...\n")

# Note: This is a demo with small dataset
# For production, use larger MITRE TTP dataset from:
# - github.com/tumeteor/mitre-ttp-mapping
# - TRAM annotated dataset

trainer.train()

print("\nTraining completed!")

In [None]:
# Save model
output_dir = "./qwen-mitre-qlora-final"
trainer.save_model(output_dir)

print(f"Model saved to {output_dir}")
print("\nSaved files:")
!ls -lh {output_dir}

---
## Inference and Evaluation

In [None]:
# Test inference

test_prompt = """<|im_start|>system
You are a cybersecurity expert trained in MITRE ATT&CK framework. Analyze threat intelligence and map it to specific TTPs.<|im_end|>
<|im_start|>user
Analyze the following cybersecurity incident description and identify the MITRE ATT&CK techniques (TTPs) used by the adversary.

Incident Description:
The attacker established persistence by modifying the Windows Registry Run key to execute a malicious DLL every time the user logs in. They also used WMI event subscriptions as a backup persistence mechanism.<|im_end|>
<|im_start|>assistant
"""

inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)

model.eval()  # disable dropout for inference
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
    )

# Decode only the newly generated tokens (skip the echoed prompt)
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
response = tokenizer.decode(new_tokens, skip_special_tokens=True)
print("Generated Response:")
print(response)

---
## Advanced Topics

### 1. Flash Attention 2

```python
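# Flash Attention 2: faster, memory-efficient exact attention kernels
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,               # FA2 requires fp16/bf16 weights
    device_map="auto",
    attn_implementation="flash_attention_2",  # Enable Flash Attention 2
)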
```

**Benefits:**
- 2-4x faster attention computation
- Lower memory usage
- Exact attention (not approximate)

### 2. Gradient Checkpointing

```python
# Trade computation for memory
model.gradient_checkpointing_enable()
```

**Impact:**
- Reduces memory by ~40%
- Increases training time by ~20%
- Essential for large batch sizes

### 3. DeepSpeed Integration

```python
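# For multi-GPU training (launch with `deepspeed` or `accelerate launch`)
training_args = TrainingArguments(
    output_dir="./qwen-mitre-qlora",
    deepspeed="ds_config_zero3.json",  # path to a DeepSpeed ZeRO-3 config file
)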
```

**ZeRO Stages:**
- ZeRO-1: Optimizer state partitioning
- ZeRO-2: + Gradient partitioning
- ZeRO-3: + Parameter partitioning (most memory efficient)
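The per-stage savings can be estimated with the model-state accounting from the ZeRO paper: 2 bytes of FP16 parameters, 2 bytes of FP16 gradients, and K = 12 bytes of Adam state (FP32 master weights, momentum, variance) per parameter. A sketch covering model states only (activations and buffers excluded), using the paper's 7.5B-parameter, 64-GPU example; "stage 0" here means plain data parallelism with no partitioning:

```python
def zero_memory_per_gpu_gb(n_params: float, n_gpus: int, stage: int, k: int = 12) -> float:
    """Per-GPU model-state memory (GB) for mixed-precision Adam under ZeRO."""
    psi = n_params
    if stage == 0:
        total = (2 + 2 + k) * psi                      # everything replicated
    elif stage == 1:
        total = 2 * psi + 2 * psi + k * psi / n_gpus   # optimizer states partitioned
    elif stage == 2:
        total = 2 * psi + (2 + k) * psi / n_gpus       # + gradients partitioned
    elif stage == 3:
        total = (2 + 2 + k) * psi / n_gpus             # + parameters partitioned
    else:
        raise ValueError("stage must be 0, 1, 2 or 3")
    return total / 1e9


for stage in range(4):
    print(f"ZeRO-{stage}: {zero_memory_per_gpu_gb(7.5e9, 64, stage):.1f} GB per GPU")
```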

---
## Memory Optimization Summary

### Memory Reduction Techniques

| Technique | Memory Saved | Speed Impact | Accuracy Impact |
|-----------|-------------|--------------|----------------|
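| **4-bit Quantization** | ~75% of weight memory (vs 16-bit) | Slower (dequantization overhead) | <1% |
| **8-bit Quantization** | ~50% of weight memory (vs 16-bit) | Slower (dequantization overhead) | <0.1% |
| **Gradient Checkpointing** | ~40% (activation memory) | -20% | None |
| **Flash Attention 2** | Attention activations (no n×n score matrix) | Faster attention; gain grows with sequence length | None (exact) |
| **LoRA (r=16)** | Gradients + optimizer states of frozen weights | Faster steps than full fine-tuning | Task-dependent |
| **Paged Optimizer** | Absorbs memory spikes (avoids OOM) | Small, when paging occurs | None |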

### Combined Impact (QLoRA)

Combining all techniques:
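- **Memory: a large reduction vs. full fine-tuning (see the 7B example below)**
- **Speed: Slower per step than 16-bit LoRA (4-bit dequantization), but runs on far smaller GPUs**
- **Accuracy: Matches 16-bit fine-tuning on the benchmarks reported in the QLoRA paper**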

**Example: 7B Model**
- Full fine-tuning: ~112 GB+ of weights, gradients and Adam states (16 bytes/param), more than a single A100-80GB without ZeRO or offloading
- QLoRA: ~9 GB (RTX 3090/4090 sufficient!)
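Why the gap is so large: with mixed-precision AdamW, every *trainable* parameter carries roughly 16 bytes of state, while LoRA/QLoRA only pay that for the small adapter. A back-of-the-envelope sketch; the 0.5% adapter fraction and the per-parameter byte counts are assumptions, and activations, quantization constants, and CUDA overhead are excluded (which is why real QLoRA runs land above the model-state figure):

```python
def training_memory_gb(n_params: float, method: str, trainable_frac: float = 0.005) -> float:
    """Rough model-state memory (weights + grads + Adam states) in GB.

    full  : 16 bytes/param (BF16 weights and grads, FP32 master copy, Adam m and v)
    lora  : frozen BF16 base (2 bytes/param) + 16 bytes per trainable adapter param
    qlora : frozen 4-bit base (~0.5 bytes/param) + 16 bytes per trainable adapter param
    """
    adapter_bytes = n_params * trainable_frac * 16    # adapter weights, grads, Adam states
    if method == "full":
        total = n_params * 16
    elif method == "lora":
        total = n_params * 2 + adapter_bytes
    elif method == "qlora":
        total = n_params * 0.5 + adapter_bytes
    else:
        raise ValueError(f"unknown method: {method!r}")
    return total / 1e9


for method in ("full", "lora", "qlora"):
    print(f"7B {method:>5}: ~{training_memory_gb(7e9, method):.1f} GB of model states (+ activations)")
```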

### Practical Recommendations

**For 24GB GPU (RTX 3090/4090):**
- Use 4-bit QLoRA
- Max model size: ~13B parameters
- Batch size: 4-8 with gradient accumulation

**For 16GB GPU (RTX 4080):**
- Use 4-bit QLoRA
- Max model size: ~7B parameters
- Batch size: 2-4 with gradient accumulation

**For 8GB GPU (RTX 3070):**
- Use 4-bit QLoRA
- Max model size: ~3B parameters
- Batch size: 1 with gradient accumulation

---
## Rank Selection Guidelines

### LoRA Rank Impact

```mermaid
graph LR
    A[LoRA Rank r] --> B[r=4]
    A --> C[r=8]
    A --> D[r=16]
    A --> E[r=32]
    A --> F[r=64]
    
    B --> B1["Params: 1× (relative to r=4)"]
    B --> B2["Fast training"]
    B --> B3["Limited capacity"]
    
    C --> C1["Params: 2×"]
    C --> C2["Good for simple tasks"]
    
    D --> D1["Params: 4×"]
    D --> D2["Recommended default"]
    D --> D3["Good balance"]
    
    E --> E1["Params: 8×"]
    E --> E2["Complex tasks"]
    
    F --> F1["Params: 16×"]
    F --> F2["Maximum capacity"]
    F --> F3["Risk of overfitting"]
    
    style D fill:#90EE90
    style D2 fill:#FFD700
```
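Trainable-parameter counts scale linearly with r and depend on which modules are targeted. A sketch for this notebook's seven target modules, assuming the Qwen2.5-1.5B layer shapes noted in the comments (taken from its published config); calling `print_trainable_parameters` on the actual PEFT model remains the authoritative check:

```python
# Assumed Qwen2.5-1.5B shapes: hidden 1536, MLP 8960, 28 layers,
# 2 KV heads x 128 head_dim -> k/v projections of width 256.
HIDDEN, MLP, LAYERS, KV_WIDTH = 1536, 8960, 28, 2 * 128

# (in_features, out_features) for each targeted projection in one decoder layer
MODULES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_WIDTH),
    "v_proj": (HIDDEN, KV_WIDTH),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, MLP),
    "up_proj": (HIDDEN, MLP),
    "down_proj": (MLP, HIDDEN),
}


def lora_trainable_params(r: int, targets=MODULES) -> int:
    """LoRA adds A (r x in) and B (out x r) per targeted module: r * (in + out)."""
    per_layer = sum(r * (fan_in + fan_out) for fan_in, fan_out in targets.values())
    return per_layer * LAYERS


for r in (4, 8, 16, 32, 64):
    print(f"r={r:>2}: {lora_trainable_params(r):,} trainable params")
```

With these assumed shapes, r=16 gives about 18.5M trainable parameters, roughly 1.2% of the ~1.54B base model; targeting only the four attention projections would cut that to about 4.4M.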

### Rank Selection Guide

| Task Complexity | Dataset Size | Recommended Rank |
|----------------|-------------|------------------|
| Simple (classification) | <1K samples | r=4-8 |
| Medium (NER, QA) | 1K-10K | r=8-16 |
| Complex (generation) | 10K-100K | r=16-32 |
| Very complex (chat) | 100K+ | r=32-64 |

**MITRE TTP Mapping:**
- Multi-label classification
- 600+ classes (hierarchical)
- **Recommended: r=16-32**

---
## Advanced Theoretical Concepts

### 1. LoRA Hyperparameter Deep Dive

#### Alpha Scaling Factor - Common Misconceptions

**Formula Reminder:** `h = W₀x + (α/r)·B·A·x`

**Key Points:**
- The LoRA update is multiplied by **α/r**, so **higher alpha relative to rank = STRONGER update** and **lower alpha relative to rank = WEAKER update**
- The original LoRA paper notes that, with Adam, tuning α is roughly equivalent to tuning the learning rate, so α is often fixed and the learning rate tuned instead
- The QLoRA paper (2023) used r=64 with α=16 (ratio 0.25) successfully, so low ratios can work when paired with a suitable learning rate

**Recommended Alpha Values:**

| Rank (r) | Alpha | Ratio | Effect | Use Case |
|----------|-------|-------|--------|----------|
| 8 | 4 | 0.5 | Weak | Minor refinement, conservative updates |
| 8 | 8 | 1.0 | **Standard** | **Recommended default** |
| 8 | 16 | 2.0 | Strong | Aggressive adaptation |
| 16 | 8 | 0.5 | Weak | Minor refinement |
| 16 | 16 | 1.0 | **Standard** | **Recommended default** |
| 16 | 32 | 2.0 | Strong | Domain shift |

**Practical Guidelines:**
- Start with `alpha = r` (ratio = 1.0), or the common `alpha = 2*r` heuristic
- If the model is not adapting enough, increase alpha (e.g., to `2*r`) or the learning rate
- If the model is overfitting or training is unstable, reduce alpha (e.g., to `r/2`) or the learning rate
- `alpha = 2*r` doubles the update scale relative to `alpha = r`: it is an aggressive choice, not a conservative one

In [None]:
# Example: Testing different alpha values

def create_lora_config_with_alpha(r=16, alpha_ratio=1.0):
    """
    Create LoRA config with specified alpha ratio
    
    Args:
        r: Rank
        alpha_ratio: Ratio of alpha to rank (0.5, 1.0, or 2.0)
    """
    alpha = int(r * alpha_ratio)
    
    config = LoraConfig(
        r=r,
        lora_alpha=alpha,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    
    print(f"Rank: {r}, Alpha: {alpha}, Ratio: {alpha_ratio} - Scaling factor: {alpha/r}")
    return config

# Standard configuration (recommended)
standard_config = create_lora_config_with_alpha(r=16, alpha_ratio=1.0)

# Stronger updates (for difficult domain adaptation): alpha/r = 2.0
strong_config = create_lora_config_with_alpha(r=16, alpha_ratio=2.0)

# Conservative updates (for minor refinements): alpha/r = 0.5
conservative_config = create_lora_config_with_alpha(r=16, alpha_ratio=0.5)

---
## Key Takeaways (Updated with Latest Best Practices)

### LoRA vs QLoRA
- **LoRA**: Fast, efficient, 0.1-1% trainable parameters
- **QLoRA**: Same as LoRA + 75% memory reduction via 4-bit quantization
- **Choice**: Use QLoRA when memory-constrained, LoRA otherwise

### Critical Hyperparameter Insights

**Alpha Scaling:**
- The update is scaled by `alpha/r`: larger alpha means larger updates
- **Standard**: `alpha = r` (ratio = 1.0); `alpha = 2*r` is also a common heuristic
- **Stronger effect**: `alpha = 2*r` (ratio = 2.0) - for aggressive adaptation
- **Weaker effect**: `alpha = r/2` (ratio = 0.5) - for conservative refinement
- **Research-backed**: the QLoRA paper used r=64, alpha=16 (ratio 0.25); alpha and learning rate interact, so tune them together

**Rank Selection (IMPORTANT):**
- **Surprising finding**: Very little difference between r=8 and r=256 when applied to all layers
- **Effective rank** is often much lower than specified rank
- **Start with r=8**, only increase if performance plateaus
- **Don't waste compute**: r=16 is usually more than enough

**Target Modules:**
- Apply LoRA to **both Attention AND MLP layers** for best performance
- Targeting only attention saves memory but reduces capacity
- This notebook targets all 7 key modules (q, k, v, o, gate, up, down)

### Quantization
- **4-bit NF4**: Best memory efficiency (75% reduction), <1% accuracy loss
- **8-bit INT8**: Good balance (50% reduction); reported in some experiments to converge as fast as BF16
- **Impact**: Minimal accuracy loss, sometimes even better convergence
- **Double quantization**: Additional ~0.37 bits/param saved (~0.3 GB per 7B params)

### PEFT Methods (2025)
- **LoRA**: Most versatile, widely supported, proven
- **AdaLoRA**: Adaptive rank allocation, can achieve same performance with fewer params
- **DoRA**: Weight-Decomposed Low-Rank Adaptation, more robust to hyperparameters
- **QLoRA**: LoRA + 4-bit quantization, enables fine-tuning on consumer GPUs
- **IA³**: Ultra-efficient (0.01% params), but limited capacity

### Training Optimizations
- **Mixed Precision**: BF16 recommended for LLMs (no loss scaling needed, large range)
- **Flash Attention 2**: 2-4x speedup, requires Ampere/Ada/Hopper GPUs
- **Gradient Checkpointing**: 40% memory reduction, 20% time increase
- **Paged Optimizer**: Essential for QLoRA, handles memory spikes

### Best Practices for Production

1. **Always use train/validation split** (85/15 recommended)
2. **Enable early stopping** (patience=3 evaluations)
3. **Mix in general examples** (20%) to prevent catastrophic forgetting
4. **Use proper evaluation metrics** for multi-label classification
5. **Start with r=8, alpha=8** before trying larger ranks
6. **Merge adapters** for production deployment (faster inference)

### For MITRE TTP Mapping
- **Model**: Qwen 2.5-1.5B (32K context beats Gemma's 8K)
- **Config**: 4-bit QLoRA, r=16, alpha=16
- **Data**: AttackQA (25K) + TRAM (4K) + custom (1K) = 30K samples
- **Evaluation**: Exact match ratio, precision/recall, Jaccard similarity
- **Deployment**: Merge adapters, use vLLM or TGI
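The evaluation metrics listed above can be computed with standard-library Python once technique IDs are parsed out of the generated text. A sketch, assuming the model emits IDs in the `T####` / `T####.###` form used in the training targets; the regex and helper names are illustrative:

```python
import re

TECHNIQUE_ID = re.compile(r"\bT\d{4}(?:\.\d{3})?\b")


def extract_techniques(text: str) -> set[str]:
    """Pull ATT&CK technique IDs (e.g. T1059 or T1059.001) out of model output."""
    return set(TECHNIQUE_ID.findall(text))


def multilabel_scores(gold: list[set[str]], pred: list[set[str]]) -> dict:
    """Exact-match ratio plus example-averaged precision, recall and Jaccard.

    An empty prediction scores precision 0; two empty sets count as Jaccard 1.
    """
    n = len(gold)
    exact = sum(g == p for g, p in zip(gold, pred)) / n
    precision = sum(len(g & p) / len(p) if p else 0.0 for g, p in zip(gold, pred)) / n
    recall = sum(len(g & p) / len(g) if g else 0.0 for g, p in zip(gold, pred)) / n
    jaccard = sum(len(g & p) / len(g | p) if g | p else 1.0 for g, p in zip(gold, pred)) / n
    return {"exact_match": exact, "precision": precision, "recall": recall, "jaccard": jaccard}


gold = [{"T1059.001", "T1105"}, {"T1003.001"}]
pred = [extract_techniques("**Techniques:** T1059.001 (PowerShell), T1071 (Application Layer Protocol)"),
        extract_techniques("**Techniques:** T1003.001 (LSASS Memory)")]
print(multilabel_scores(gold, pred))
```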

### Production Deployment
1. Fine-tune with QLoRA on consumer GPU (RTX 3090/4090)
2. Merge adapters into base model
3. Optionally quantize merged model to 4-bit/8-bit
4. Deploy with vLLM or TGI (high-throughput serving with continuous batching)
5. Monitor performance drift, retrain with new MITRE versions

### Memory Requirements
- **24GB GPU**: Can fine-tune up to 13B models with 4-bit QLoRA
- **16GB GPU**: Can fine-tune up to 7B models with 4-bit QLoRA
- **8GB GPU**: Can fine-tune up to 3B models with 4-bit QLoRA
- Use gradient accumulation for effective larger batch sizes

### Common Mistakes to Avoid
1. Changing `alpha` without realizing it scales every update by `alpha/r` (raising alpha acts much like raising the learning rate)
2. Using very high rank (r=256) without testing r=8 first (wasted compute)
3. No validation set (can't detect overfitting)
4. Training too long (catastrophic forgetting)
5. Not applying LoRA to MLP layers (limited capacity)
6. Requesting Flash Attention 2 on unsupported GPUs or without `flash-attn` installed (model loading fails; use `attn_implementation="sdpa"` instead)

This notebook reflects 2024-2025 best practices based on latest research!

---
## Model Comparison: Qwen 2.5-1.5B vs Gemma 2-2B

### Why Qwen 2.5-1.5B Was Chosen for This Notebook

| Feature | Qwen 2.5-1.5B | Gemma 2-2B | Winner |
|---------|---------------|------------|--------|
| **Parameters** | 1.54B | 2.6B | Gemma (larger) |
| **Context Length** | **32K tokens** | 8K tokens | **Qwen** |
| **Training Data** | 18T tokens | 2T tokens | **Qwen** |
| **Math/Coding** | Excellent | Good | **Qwen** |
| **License** | Apache 2.0 | Gemma Terms | **Qwen** (more permissive) |
| **Efficiency** | Better (smaller) | Needs more resources | **Qwen** |
| **Architecture** | Qwen 2.5 | Gemma 2 | Both modern |
| **Release** | Sept 2024 | July 2024 | Qwen (newer) |

### Performance on General Benchmarks

**Qwen 2.5-1.5B-Instruct:**
- MMLU: 61.4%
- HumanEval (coding): 37.8%
- GSM8K (math): 63.0%
- Overall: Strong all-around performance

**Gemma 2-2B-Instruct:**
- MMLU: ~55-58%
- HumanEval (coding): ~30-35%
- GSM8K (math): ~45-50%
- Overall: Good but slightly behind

### Cybersecurity-Specific Considerations

**For MITRE TTP Mapping:**

1. **Context Length is Critical**
   - CTI reports often 5K-20K tokens
   - Qwen's 32K context >> Gemma's 8K
   - **Verdict: Qwen wins for long documents**

2. **Structured Output Quality**
   - Both can be prompted (or constrained at decoding time) to emit JSON
   - Qwen slightly better at following formats
   - **Verdict: Slight edge to Qwen**

3. **Training Data Diversity**
   - Qwen: 18T tokens, multilingual
   - Gemma: Smaller dataset
   - **Verdict: Qwen has more diverse knowledge**

4. **Fine-tuning Efficiency**
   - Qwen: 1.54B params = faster training
   - Gemma: 2.6B params = 70% more memory
   - **Verdict: Qwen more efficient**

### When to Choose Gemma 2-2B Instead

**Consider Gemma if:**
- You're in the Google ecosystem (Vertex AI, Colab)
- You need stronger safety filtering
- 8K context is sufficient for your use case
- You prefer Google's research lineage
- You want slightly more capacity (2.6B params)

**Stick with Qwen if:**
- Long CTI reports (>8K tokens)
- Need Apache 2.0 licensing
- Want better coding/structured output
- Prefer the newer model (Sept 2024 vs July 2024)
- Need maximum efficiency

### Alternative Models Worth Considering

1. **Llama 3.2-3B-Instruct**
   - 3B params, 128K context
   - Strong general performance
   - Trade-off: Larger, slower

2. **Phi-4 (14B)**
   - Excellent reasoning
   - 16K context
   - Trade-off: Much larger

3. **SecurityLLM-8B** (Purpose-built for cybersecurity)
   - Source: arxiv.org/html/2504.21039
   - Pre-trained on security corpus
   - Trade-off: Less general purpose

### Final Recommendation

**For MITRE TTP Mapping: Stick with Qwen 2.5-1.5B**

Reasons:
- 32K context handles full CTI reports
- Better coding ability for structured JSON output
- More efficient fine-tuning (1.54B vs 2.6B)
- Newer model with latest training techniques
- Apache 2.0 license (permissive, allows commercial use)

The 32K context window is the **decisive factor**: many CTI reports exceed 8K tokens once formatted with system prompts and examples.

### 3. Learning Rate Schedules Explained

**Why Cosine Schedule Works Well for Fine-tuning:**

```python
# Learning rate at the start of each epoch (peak 2e-4, no warmup,
# decaying to 0 by the end of epoch 3)
# Epoch:     1        2         3
# Cosine:   2e-4 → 1.5e-4 → 5e-5    (smooth decay)
# Linear:   2e-4 → 1.3e-4 → 6.7e-5  (linear decay)
# Constant: 2e-4 → 2e-4   → 2e-4    (no decay)
```

**Schedule Comparison:**

| Schedule | Best For | Pros | Cons |
|----------|----------|------|------|
| **Cosine** | Fine-tuning, instruction tuning | Smooth convergence, avoids sudden drops | Slightly slower initial learning |
| **Linear** | Short training runs | Simple, predictable | Can be too aggressive |
| **Constant + Warmup** | Continued pretraining | Maintains high LR longer | Risk of overshooting |
| **Inverse Sqrt** | Very long training | Proven for transformers | Not ideal for short fine-tuning |
| **Polynomial** | Custom decay rate | Flexible | Requires tuning |

**Warmup Phase Importance:**

Warmup gradually increases LR from 0 to target over first N steps:
- Prevents large gradient updates early
- Stabilizes training with quantized weights
- Especially important for QLoRA (4-bit weights)

```python
# Recommended warmup
warmup_ratio=0.1    # 10% of total steps
# For 1000 total steps: 100 warmup steps
# LR rises linearly: 0 → 2e-5 (step 10) → 1e-4 (step 50) → 2e-4 (step 100)
```
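To see where these numbers come from, here is a minimal pure-Python sketch of linear warmup followed by linear, cosine or constant decay. It is written to follow the same formulas as the Hugging Face `transformers` schedulers selected via `lr_scheduler_type` (a single cosine half-cycle decaying to 0 at the last step). `lr_at` itself is a hypothetical helper for illustration, not a library function.

```python
import math


def lr_at(step, total_steps, peak_lr=2e-4, warmup_steps=0, schedule="cosine"):
    """Learning rate at an optimizer step: linear warmup, then the chosen decay."""
    if step < warmup_steps:
        # Warmup: ramp linearly from 0 to peak_lr
        return peak_lr * step / max(1, warmup_steps)
    if schedule == "constant":
        return peak_lr
    # Fraction of the post-warmup phase completed (0.0 -> 1.0)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    if schedule == "linear":
        return peak_lr * max(0.0, 1.0 - progress)
    if schedule == "cosine":
        # Half a cosine wave: peak_lr at progress 0, 0 at progress 1
        return peak_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    raise ValueError(f"unknown schedule: {schedule}")


# 3 epochs of 100 steps each, no warmup: LR at the start of every epoch
for step in (0, 100, 200):
    cos_lr = lr_at(step, total_steps=300, schedule="cosine")
    lin_lr = lr_at(step, total_steps=300, schedule="linear")
    print(f"step {step:>3}: cosine {cos_lr:.2e}  linear {lin_lr:.2e}")

# warmup_ratio=0.1 on 1000 total steps -> 100 warmup steps
for step in (10, 50, 100):
    print(f"step {step:>3}: {lr_at(step, total_steps=1000, warmup_steps=100):.2e}")
```

The cosine curve stays above the linear one early in training and drops below it later, which is the "smooth decay" the comparison table refers to.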

### 4. Catastrophic Forgetting in Fine-tuning

**The Problem:**

When fine-tuning on narrow domains (like MITRE TTPs), models can "forget" general capabilities:

```
Before fine-tuning: "What is Python?" → Detailed, accurate answer
After fine-tuning:  "What is Python?" → Tries to map to MITRE technique!
```

**Why it happens:**
- Small dataset pushes weights toward specific domain
- General knowledge connections weakened
- More severe with higher learning rates and longer training

**Mitigation Strategies:**

#### Strategy 1: Mix in General Examples (Recommended)

```python
# 80% domain-specific + 20% general instruction data
mitre_samples = load_mitre_dataset()      # 8,000 samples
general_samples = load_general_qa()       # 2,000 samples
combined = mitre_samples + general_samples

# This maintains general capability while adapting to domain
```
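Here is a self-contained version of the 80/20 mix, assuming both datasets are plain Python lists of examples. `mix_with_general` is a hypothetical helper. It solves for how many general examples are needed so that they make up the requested fraction of the final set, samples them without replacement, and shuffles with a fixed seed for reproducibility.

```python
import random


def mix_with_general(domain_samples, general_pool, general_fraction=0.2, seed=42):
    """Return a shuffled list in which ~general_fraction of items come from general_pool."""
    if not 0.0 <= general_fraction < 1.0:
        raise ValueError("general_fraction must be in [0, 1)")
    # Solve n_general / (n_domain + n_general) = general_fraction for n_general
    n_general = round(len(domain_samples) * general_fraction / (1.0 - general_fraction))
    n_general = min(n_general, len(general_pool))  # cannot sample more than we have
    rng = random.Random(seed)
    mixed = list(domain_samples) + rng.sample(list(general_pool), n_general)
    rng.shuffle(mixed)  # spread general examples across all batches
    return mixed


domain = [{"source": "mitre", "id": i} for i in range(8000)]
general = [{"source": "general", "id": i} for i in range(5000)]
mixed = mix_with_general(domain, general, general_fraction=0.2)
n_general = sum(ex["source"] == "general" for ex in mixed)
print(len(mixed), n_general)  # 10000 2000
```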

#### Strategy 2: Lower Learning Rate

```python
# More conservative adaptation
learning_rate=5e-5  # Instead of 2e-4
# Takes longer but preserves more general knowledge
```

#### Strategy 3: Shorter Training

```python
# Don't overtrain
num_train_epochs=2  # 1-2 epochs instead of 3-5
# Monitor when validation performance plateaus
```

#### Strategy 4: Regularization via LoRA Alpha

```python
# Lower alpha = more conservative: the LoRA update is scaled by alpha / r
lora_alpha=8  # With r=16 (ratio=0.5)
# Adapter updates are scaled down, so weights drift less from the base model
```
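The direction of the alpha effect follows from the LoRA formula `h = W₀x + (α/r)·B·A·x`. For fixed adapter matrices, the update to the frozen weight is multiplied by α/r. The NumPy sketch below uses random matrices as stand-ins for trained adapters. Real adapters start with B = 0 and are learned; the random values here only make the scale visible. The shapes follow the convention where `B @ A` is d_out × d_in, and `lora_delta` is a hypothetical helper.

```python
import numpy as np


def lora_delta(B, A, alpha, r):
    """Weight update contributed by a LoRA adapter: (alpha / r) * B @ A."""
    return (alpha / r) * (B @ A)


rng = np.random.default_rng(0)
d, r = 64, 16
A = rng.normal(size=(r, d))  # (r, d_in)
B = rng.normal(size=(d, r))  # (d_out, r)

for alpha in (8, 16, 32):
    delta = lora_delta(B, A, alpha, r)
    print(f"alpha={alpha:>2}  alpha/r={alpha / r:.1f}  ||delta W||_F={np.linalg.norm(delta):8.1f}")
```

The Frobenius norm of the update doubles each time alpha doubles. A ratio of 0.5 applies half of the raw `B @ A` product, and a ratio of 2.0 applies twice it.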

**Evaluation:**

Always test on both domain-specific AND general benchmarks:

```python
# Domain performance
ttp_accuracy = evaluate_mitre_ttps(model)

# General capability (detect forgetting)
general_qa_accuracy = evaluate_general_qa(model)
coding_ability = evaluate_code_generation(model)
```
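Forgetting only shows up when you compare these scores before and after fine-tuning. Below is a minimal sketch, assuming each evaluation helper returns a score on a 0-1 scale and that the same helpers were first run on the base model. `forgetting_report` is a hypothetical helper, and the numbers in the example are made-up placeholders, not measured results.

```python
def forgetting_report(before, after, max_drop=0.02):
    """Flag metrics that fell by more than max_drop (absolute, 0-1 scale).

    before / after: {metric_name: score} for the base and fine-tuned model.
    Returns {metric_name: drop} for every flagged metric.
    """
    flagged = {}
    for name, base_score in before.items():
        if name not in after:
            continue
        drop = base_score - after[name]
        if drop > max_drop:
            flagged[name] = round(drop, 4)
    return flagged


# Placeholder scores for illustration only, not measured results
before = {"ttp_accuracy": 0.41, "general_qa_accuracy": 0.72, "coding_ability": 0.35}
after = {"ttp_accuracy": 0.78, "general_qa_accuracy": 0.61, "coding_ability": 0.34}
print(forgetting_report(before, after))  # {'general_qa_accuracy': 0.11}
```

A flagged general-capability metric is the signal to apply the mitigations above: more general data in the mix, a lower learning rate, or fewer epochs.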

---
## Best Practices for Production Fine-tuning

### 1. Proper Train/Validation Split with Early Stopping

In [None]:
# Create proper train/validation split
from datasets import Dataset

# Format every MITRE example as a single training string
all_formatted_data = [
    {"text": format_mitre_example(example)}
    for example in mitre_samples
]

# Create dataset and split
dataset = Dataset.from_list(all_formatted_data)
split_dataset = dataset.train_test_split(test_size=0.15, seed=42)

train_dataset_split = split_dataset['train']
val_dataset_split = split_dataset['test']

print(f"Training samples: {len(train_dataset_split)}")
print(f"Validation samples: {len(val_dataset_split)}")
print(f"Validation ratio: {len(val_dataset_split) / len(dataset):.1%}")

# Tokenize both splits
train_tokenized = train_dataset_split.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)

val_tokenized = val_dataset_split.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)

print("\nDatasets ready for training with validation")

In [None]:
# Training with validation and early stopping

from transformers import EarlyStoppingCallback, Trainer, TrainingArguments

training_args_with_validation = TrainingArguments(
    output_dir="./qwen-mitre-qlora-validated",
    run_name="qwen-mitre-with-validation",
    
    # Training
    num_train_epochs=10,                     # Upper bound; early stopping ends sooner (cosine decay still spans all 10 epochs)
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    
    # Mixed precision
    bf16=True,
    tf32=True,
    
    # Optimization
    optim="paged_adamw_8bit",
    gradient_checkpointing=True,
    max_grad_norm=1.0,
    
    # Logging
    logging_steps=10,
    logging_dir="./logs",
    report_to="none",
    
    # Evaluation and early stopping
    eval_strategy="steps",                   # Evaluate every N steps
    eval_steps=50,                           # Evaluate every 50 steps
    save_strategy="steps",
    save_steps=50,
    save_total_limit=3,                      # Keep at most 3 checkpoints (the best one is always retained)
    load_best_model_at_end=True,             # Load best model when training ends
    metric_for_best_model="eval_loss",       # Use validation loss as metric
    greater_is_better=False,                 # Lower loss is better
    
    # Performance
    dataloader_num_workers=4,
    group_by_length=True,
    seed=42,
)

# Trainer with early stopping
trainer_with_validation = Trainer(
    model=model,
    args=training_args_with_validation,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,              # Add validation dataset
    data_collator=data_collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]  # Stop if no improvement for 3 evals
)

print("Trainer configured with validation and early stopping")
print(f"Will stop training if validation loss doesn't improve for {3 * 50} steps")
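To make the patience rule concrete, here is a simplified, library-free sketch of the logic `EarlyStoppingCallback` applies to a lower-is-better metric like `eval_loss`. An evaluation resets the counter only if it beats the best value so far (by more than an optional threshold), and training stops once `patience` consecutive evaluations fail to improve. `early_stopping_step` is hypothetical and only mirrors the idea; the real callback runs inside the `Trainer` loop. The loss values below are invented for illustration.

```python
def early_stopping_step(history, patience=3, threshold=0.0):
    """Return the index of the eval at which training would stop, or None.

    history: eval_loss values in evaluation order (lower is better).
    """
    best = float("inf")
    bad_evals = 0
    for i, loss in enumerate(history):
        if loss < best - threshold:
            best = loss      # improvement: remember it and reset the counter
            bad_evals = 0
        else:
            bad_evals += 1   # no improvement at this evaluation
            if bad_evals >= patience:
                return i
    return None


eval_losses = [1.20, 1.00, 0.95, 0.96, 0.97, 0.99, 0.90]  # one value every 50 steps
stop_idx = early_stopping_step(eval_losses, patience=3)
print(f"Stops after eval #{stop_idx + 1} (step {(stop_idx + 1) * 50})")
```

The late improvement to 0.90 is never reached, which is the trade-off of a small patience. With `load_best_model_at_end=True`, the checkpoint from the third evaluation (step 150) would be restored.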

### 2. Evaluation Metrics for Multi-label Classification (MITRE TTPs)

In [None]:
# Evaluation metrics for MITRE TTP mapping

from sklearn.metrics import (
    precision_recall_fscore_support,
    hamming_loss,
    accuracy_score,
    jaccard_score
)
import numpy as np

def evaluate_ttp_predictions(y_true, y_pred, technique_names=None):
    """
    Comprehensive evaluation for multi-label TTP classification
    
    Args:
        y_true: Ground truth binary matrix (n_samples, n_techniques)
        y_pred: Predicted binary matrix (n_samples, n_techniques)
        technique_names: Optional list of technique names for per-class metrics
    
    Returns:
        Dictionary of metrics
    """
    
    # Exact match ratio (all techniques must match)
    exact_match = accuracy_score(y_true, y_pred)
    
    # Micro-averaged metrics (treat each technique prediction equally)
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
        y_true, y_pred, average='micro', zero_division=0
    )
    
    # Macro-averaged metrics (average per technique, then average)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        y_true, y_pred, average='macro', zero_division=0
    )
    
    # Hamming loss (fraction of wrong labels)
    h_loss = hamming_loss(y_true, y_pred)
    
    # Jaccard similarity (intersection over union)
    jaccard = jaccard_score(y_true, y_pred, average='samples', zero_division=0)
    
    metrics = {
        'exact_match_ratio': exact_match,
        'precision_micro': precision_micro,
        'recall_micro': recall_micro,
        'f1_micro': f1_micro,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'f1_macro': f1_macro,
        'hamming_loss': h_loss,
        'jaccard_similarity': jaccard,
    }
    
    # Per-technique metrics (if names provided)
    if technique_names:
        precision_per_class, recall_per_class, f1_per_class, support = precision_recall_fscore_support(
            y_true, y_pred, average=None, zero_division=0
        )
        
        print("\nPer-Technique Metrics:")
        print(f"{'Technique':<20} {'Precision':>10} {'Recall':>10} {'F1':>10} {'Support':>10}")
        print("-" * 65)
        for i, name in enumerate(technique_names):
            print(f"{name:<20} {precision_per_class[i]:>10.3f} {recall_per_class[i]:>10.3f} "
                  f"{f1_per_class[i]:>10.3f} {int(support[i]):>10}")
    
    return metrics

# What each metric measures
print("Evaluation Metrics for MITRE TTP Mapping:\n")
print("1. Exact Match Ratio: % of samples where ALL techniques are correctly predicted")
print("2. Precision (Micro): Overall precision across all technique predictions")
print("3. Recall (Micro): Overall recall across all technique predictions")
print("4. F1 (Micro): Harmonic mean of precision and recall")
print("5. Hamming Loss: Fraction of incorrect technique labels")
print("6. Jaccard Similarity: Average IoU between predicted and true technique sets")
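`evaluate_ttp_predictions` expects binary matrices, but a generative model emits text. The sketch below covers the missing step, assuming outputs mention ATT&CK IDs in the usual `T1234` / `T1234.001` form. It extracts the IDs with a regular expression, then computes exact match, micro precision/recall/F1 and sample-averaged Jaccard directly on sets. `extract_ttps` and `set_metrics` are hypothetical helpers. If you prefer the sklearn path, `sklearn.preprocessing.MultiLabelBinarizer` can turn the same sets into the matrices the function above expects.

```python
import re

# ATT&CK technique IDs: T1059, or sub-techniques such as T1059.001
TTP_PATTERN = re.compile(r"\bT\d{4}(?:\.\d{3})?\b")


def extract_ttps(generated_text):
    """Return the set of technique IDs mentioned in a model's output."""
    return set(TTP_PATTERN.findall(generated_text))


def set_metrics(true_sets, pred_sets, empty_jaccard=0.0):
    """Multi-label metrics computed on sets of technique IDs."""
    tp = fp = fn = exact = 0
    jaccard_total = 0.0
    for true, pred in zip(true_sets, pred_sets):
        tp += len(true & pred)
        fp += len(pred - true)
        fn += len(true - pred)
        exact += int(true == pred)
        union = true | pred
        # Both sets empty: score empty_jaccard (0.0, in the spirit of zero_division=0)
        jaccard_total += len(true & pred) / len(union) if union else empty_jaccard
    n = len(true_sets)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "exact_match_ratio": exact / n,
        "precision_micro": precision,
        "recall_micro": recall,
        "f1_micro": f1,
        "jaccard_similarity": jaccard_total / n,
    }


outputs = ["Techniques: T1059, T1566.001", "Maps to T1486 and T1490."]
ground_truth = [{"T1059", "T1566.001"}, {"T1486"}]
predicted = [extract_ttps(text) for text in outputs]
print(set_metrics(ground_truth, predicted))
```

In this toy example the second output adds one extra technique (T1490), so recall stays at 1.0 while precision drops to 0.75 and only half the samples are exact matches.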

### 3. Merging LoRA Adapters for Production Deployment

In [None]:
# Merging LoRA adapters into base model for faster inference

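import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Option 1: Merge adapters (creates single model, faster inference)
# Load the base model in BF16 (not 4-bit) so the LoRA update is
# merged into unquantized weights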
base_model_for_merge = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

# Load LoRA adapters
peft_model = PeftModel.from_pretrained(
    base_model_for_merge,
    "./qwen-mitre-qlora-final"  # Path to saved adapters
)

# Merge adapters into base weights
merged_model = peft_model.merge_and_unload()

# Save merged model
merged_model.save_pretrained("./qwen-mitre-merged")
tokenizer.save_pretrained("./qwen-mitre-merged")

print("Merged model saved to ./qwen-mitre-merged")
print("\nBenefits of merging:")
print("- Faster inference (no adapter overhead)")
print("- Simpler deployment (one self-contained model, no separate adapter)")
print("- Compatible with vLLM, TGI, etc.")
print("\nTrade-off:")
print("- Larger on disk than the adapter alone (full BF16 weights, not 4-bit)")
print("- Can't switch adapters dynamically")
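Merging is exact in principle. `merge_and_unload()` folds the scaled low-rank update into the base weight, so `W' = W₀ + (α/r)·B·A` produces the same output as running the adapter alongside the frozen layer, using one matmul instead of three. Below is a NumPy sketch of that identity. Random matrices stand in for real weights; in practice the merged weights are stored in BF16, which adds small rounding differences.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 32, 48, 8, 16
W0 = rng.normal(size=(d_out, d_in))     # frozen base weight
A = rng.normal(size=(r, d_in)) * 0.01   # LoRA down-projection
B = rng.normal(size=(d_out, r)) * 0.01  # LoRA up-projection
x = rng.normal(size=(d_in,))
scale = alpha / r

# Unmerged: base matmul plus a separate low-rank path (two extra matmuls)
h_adapter = W0 @ x + scale * (B @ (A @ x))

# Merged: fold the update into one dense weight, then a single matmul
W_merged = W0 + scale * (B @ A)
h_merged = W_merged @ x

print(np.allclose(h_adapter, h_merged))  # True, up to float rounding
```

The merged weight has the same shape as the original, which is why the merged model drops into vLLM or TGI like any ordinary checkpoint.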

### 4. Hyperparameter Search Strategy

In [None]:
# Recommended hyperparameter search grid for LoRA/QLoRA

hyperparameter_search_grid = {
    # LoRA parameters
    'rank': [8, 16, 32],                    # Start with 8!
    'lora_alpha_ratio': [0.5, 1.0, 2.0],    # alpha = r * ratio
    'lora_dropout': [0.0, 0.05, 0.1],
    
    # Training parameters
    'learning_rate': [5e-5, 1e-4, 2e-4],
    'num_epochs': [1, 2, 3],
    'batch_size': [4, 8],                   # Adjust for GPU memory
}

# RECOMMENDED SEARCH STRATEGY:

print("Phase 1: Quick baseline")
print("-" * 50)
baseline = {
    'rank': 8,
    'lora_alpha': 8,                        # alpha = r
    'lora_dropout': 0.05,
    'learning_rate': 2e-4,
    'num_epochs': 2,
    'batch_size': 4,
}
print(f"Baseline config: {baseline}")
print("Train for 2 epochs, evaluate performance\n")

print("Phase 2: Optimize rank (if needed)")
print("-" * 50)
print("Try ranks: [8, 16, 32]")
print("Keep other params from baseline")
print("Pick rank with best validation performance\n")

print("Phase 3: Optimize learning rate")
print("-" * 50)
print("Try LRs: [5e-5, 1e-4, 2e-4]")
print("Use best rank from Phase 2")
print("Pick LR with best validation performance\n")

print("Phase 4: Fine-tune other params")
print("-" * 50)
print("Try alpha ratios: [0.5, 1.0, 2.0] (lower if overfitting, higher if underfitting)")
print("Try dropout: [0.0, 0.05] (if overfitting, try 0.1)")
print("Try epochs: [1, 2, 3] (watch for overfitting)\n")

print("IMPORTANT: Don't optimize all at once!")
print("- Total search space: 3 × 3 × 3 × 3 × 3 × 2 = 486 combinations")
print("- Sequential search: ~10-15 runs to find good config")
print("- Use validation set to avoid overfitting to test set")
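The phased strategy above amounts to a coordinate-wise search: tune one hyperparameter at a time and carry the best value forward. Below is a minimal sketch, assuming you wrap training plus validation into a callable that returns a higher-is-better score. `staged_search` is a hypothetical helper, and the toy `fake_validation_score` exists only so the example runs without a GPU.

```python
def staged_search(baseline, phases, evaluate):
    """Tune one hyperparameter at a time, carrying the best value forward.

    baseline: dict of starting hyperparameters
    phases:   list of (param_name, candidate_values), tuned in order
    evaluate: callable(config) -> validation score (higher is better)
    """
    best_config = dict(baseline)
    best_score = evaluate(best_config)
    runs = 1
    for name, candidates in phases:
        start_value = best_config.get(name)
        for value in candidates:
            if value == start_value:
                continue  # this value was already scored at the start of the phase
            trial = {**best_config, name: value}
            score = evaluate(trial)
            runs += 1
            if score > best_score:
                best_config, best_score = trial, score
    return best_config, best_score, runs


def fake_validation_score(cfg):
    """Toy stand-in for 'train, then score on the validation set'."""
    return -(abs(cfg["rank"] - 16)
             + abs(cfg["learning_rate"] - 1e-4) * 1e4
             + abs(cfg["num_epochs"] - 2))


baseline = {"rank": 8, "lora_alpha": 8, "lora_dropout": 0.05,
            "learning_rate": 2e-4, "num_epochs": 2, "batch_size": 4}
phases = [
    ("rank", [8, 16, 32]),
    ("learning_rate", [5e-5, 1e-4, 2e-4]),
    ("lora_dropout", [0.0, 0.05]),
    ("num_epochs", [1, 2, 3]),
]
best, score, runs = staged_search(baseline, phases, fake_validation_score)
print(best, runs)
```

If you tune rank with a fixed alpha ratio, recompute `lora_alpha` from `rank` inside your evaluate function; otherwise the ratio drifts as rank changes.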

---
## Production Checklist

### Before Fine-tuning
- [ ] **Data Quality**: Clean, diverse, representative dataset
- [ ] **Data Size**: Minimum 100-1000 examples per task
- [ ] **Format**: Consistent instruction format
- [ ] **Validation Set**: 10-20% held out for evaluation
- [ ] **Baseline**: Test base model performance first

### Configuration
- [ ] **Quantization**: Choose based on GPU memory
- [ ] **Rank**: Start with r=8, increase (16, 32) only if validation results call for it
- [ ] **Learning Rate**: 2e-4 to 5e-5 (lower for larger models)
- [ ] **Batch Size**: As large as memory allows
- [ ] **Epochs**: 1-3 (monitor for overfitting; early stopping helps)

### During Training
- [ ] **Monitor Loss**: Should decrease smoothly
- [ ] **Check Gradients**: Watch for exploding/vanishing
- [ ] **Sample Outputs**: Generate periodically
- [ ] **Save Checkpoints**: Every epoch or N steps
- [ ] **Log Metrics**: Use W&B or TensorBoard

### After Training
- [ ] **Evaluate**: Test on a held-out test set (not the validation set used for early stopping)
- [ ] **Compare**: Benchmark against base model
- [ ] **Test Edge Cases**: Adversarial examples
- [ ] **Merge Adapters**: Optional, for inference speed
- [ ] **Deploy**: Use inference optimization (vLLM, TGI)

### MITRE TTP Specific
- [ ] **Coverage**: Test all 14 Enterprise ATT&CK tactics
- [ ] **Precision**: Minimize false positives
- [ ] **Recall**: Minimize missed techniques (false negatives)
- [ ] **Hierarchy**: Respect sub-technique relationships
- [ ] **Updates**: Plan for new MITRE versions
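The coverage and hierarchy items can be checked mechanically. Below is a minimal sketch, assuming each test example carries the tactic names of its techniques (for example, derived from the ATT&CK STIX data) and that technique IDs follow the `T1234.001` convention. `tactic_coverage` and `parent_technique` are hypothetical helpers. Rolling IDs up to their parent technique gives a more lenient, hierarchy-aware comparison to report alongside exact sub-technique matches.

```python
from collections import Counter

# The 14 tactics of the MITRE ATT&CK Enterprise matrix
ENTERPRISE_TACTICS = [
    "Reconnaissance", "Resource Development", "Initial Access", "Execution",
    "Persistence", "Privilege Escalation", "Defense Evasion", "Credential Access",
    "Discovery", "Lateral Movement", "Collection", "Command and Control",
    "Exfiltration", "Impact",
]


def tactic_coverage(test_tactics):
    """Count test examples per tactic and list tactics with no examples.

    test_tactics: one iterable of tactic names per test example.
    """
    counts = Counter(tactic for tactics in test_tactics for tactic in set(tactics))
    missing = [t for t in ENTERPRISE_TACTICS if counts[t] == 0]
    return counts, missing


def parent_technique(technique_id):
    """Map a sub-technique such as T1059.001 to its parent, T1059."""
    return technique_id.split(".")[0]


counts, missing = tactic_coverage([["Execution"], ["Initial Access", "Execution"]])
print(counts["Execution"], len(missing))

# Hierarchy-aware comparison: T1566.001 and T1566.002 share the parent T1566
true_ids, pred_ids = {"T1566.001"}, {"T1566.002"}
same_parent = {parent_technique(t) for t in true_ids} == {parent_technique(t) for t in pred_ids}
print(true_ids == pred_ids, same_parent)
```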

---
## Resources and References

### Papers
1. **LoRA**: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)
2. **QLoRA**: [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
3. **DoRA**: [DoRA: Weight-Decomposed Low-Rank Adaptation](https://arxiv.org/abs/2402.09353)
4. **AdaLoRA**: [Adaptive Budget Allocation for Parameter-Efficient Fine-Tuning](https://arxiv.org/abs/2303.10512)

### MITRE ATT&CK Resources
1. **TRAM**: [Threat Report ATT&CK Mapper](https://github.com/center-for-threat-informed-defense/tram)
2. **Dataset**: [MITRE TTP Mapping Dataset](https://github.com/tumeteor/mitre-ttp-mapping)
3. **Framework**: [MITRE ATT&CK](https://attack.mitre.org/)

### Code Repositories
1. **Hugging Face PEFT**: [github.com/huggingface/peft](https://github.com/huggingface/peft)
2. **QLoRA**: [github.com/artidoro/qlora](https://github.com/artidoro/qlora)
3. **Transformers**: [github.com/huggingface/transformers](https://github.com/huggingface/transformers)

### Models
1. **Qwen 2.5**: [Qwen/Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct)
2. **Phi-4**: [microsoft/phi-4](https://huggingface.co/microsoft/phi-4)
3. **Gemma 3**: [google/gemma-3-1b](https://huggingface.co/google/gemma-3-1b)
4. **Llama 3.2**: [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct)

---
## Key Takeaways

### LoRA vs QLoRA
- **LoRA**: Fast, efficient, 0.1-1% trainable parameters
- **QLoRA**: Same as LoRA, plus ~75% less memory for the frozen base weights via 4-bit quantization
- **Choice**: Use QLoRA when memory-constrained, LoRA otherwise

### Quantization
- **4-bit NF4**: Best memory efficiency (75% reduction)
- **8-bit INT8**: Good balance (50% reduction, higher fidelity than 4-bit)
- **Impact**: Usually small; the QLoRA paper reports 4-bit NF4 fine-tuning matching 16-bit baselines on its benchmarks

### PEFT Methods (2025)
- **LoRA**: Most versatile, widely supported
- **AdaLoRA**: Adaptive rank, better performance with fewer params
- **DoRA**: More stable training, separates magnitude/direction
- **IA³**: Ultra-efficient (0.01% params), limited capacity

### Training Optimizations
- **Mixed Precision**: BF16 recommended for LLMs
- **Flash Attention 2**: 2-4x speedup, lower memory
- **Gradient Checkpointing**: 40% memory reduction
- **Paged Optimizer**: Recommended for QLoRA (absorbs memory spikes)

### For MITRE TTP Mapping
- **Model**: Qwen 2.5-1.5B (32K context, efficient to fine-tune)
- **Config**: 4-bit QLoRA, r=16-32
- **Data**: TRAM dataset + custom examples
- **Evaluation**: Multi-label metrics, tactic coverage

### Production Deployment
- Fine-tune with QLoRA on consumer GPU
- Merge adapters for inference (optional)
- Use vLLM or TGI for serving
- Monitor drift, update with new MITRE versions