# Training Self-RAG Models

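Train the critic and generator models for the Self-RAG system using QLoRA. On macOS the training scripts disable 4-bit quantization, so the runs shown below fine-tune plain LoRA adapters on an unquantized base model.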

## Prerequisites

Before training:
1. ✅ Documents indexed (from notebook 02)
2. ✅ Training data prepared (from notebook 01)
3. ⚠️ Training requires significant compute (GPU recommended)
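
The prerequisites above can be checked before any training starts. This is a minimal sketch, assuming the relative paths that appear in this notebook's commands and logs; `find_missing` is a hypothetical helper, not part of the project.

```python
from pathlib import Path

# Paths copied from the commands and log output in this notebook.
# Adjust them if your project layout differs.
REQUIRED_PATHS = [
    "../data/samples/sample_qa_data.json",
    "../data/embeddings/faiss_index.faiss",
    "../data/embeddings/documents.pkl",
    "../configs/critic_config.yaml",
    "../configs/generator_config.yaml",
    "../configs/retrieval_config.yaml",
]


def find_missing(paths, base="."):
    """Return the entries of `paths` that do not exist relative to `base`."""
    base = Path(base)
    return [p for p in paths if not (base / p).exists()]


missing = find_missing(REQUIRED_PATHS)
if missing:
    print("Missing prerequisites:")
    for path in missing:
        print(f"  - {path}")
else:
    print("All prerequisite files found.")
```

Run it from the notebook's directory so the `../` prefixes resolve the same way as in the cells below.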

## Step 1: Generate Training Labels

Generate reflection-token labels for the sample Q&A data.

In [1]:
%%bash
# Generate labels using rule-based approach
uv run python -m src.training.generate_labels \
    --input ../data/samples/sample_qa_data.json \
    --output-dir ../data/training \
    --num-samples 10 && \
echo "✅ Labels generated!"

Generating labels: 100%|██████████| 10/10 [00:00<00:00, 45491.37it/s]



Labeled 10 examples
Saved to ../data/training/labeled_data.json
✅ Labels generated!
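
To give a feel for what a rule-based labeler can look like, here is a rough sketch based on token overlap. Everything in it is an assumption for illustration: the token strings, the thresholds, and the function names are hypothetical and are not taken from `src.training.generate_labels`, whose actual rules are not shown in this notebook.

```python
import re


def _tokens(text):
    """Lowercase word tokens as a set (a crude, illustrative tokenizer)."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))


def overlap(a, b):
    """Fraction of tokens in `a` that also appear in `b` (0.0 if `a` is empty)."""
    ta, tb = _tokens(a), _tokens(b)
    return len(ta & tb) / len(ta) if ta else 0.0


def label_example(question, passage, answer):
    """Assign one hypothetical reflection token per label type.

    Thresholds (0.3, 0.6) are arbitrary illustration values.
    """
    rel = overlap(question, passage)  # does the passage address the question?
    sup = overlap(answer, passage)    # is the answer grounded in the passage?
    if sup >= 0.6:
        support = "[Fully Supported]"
    elif sup >= 0.3:
        support = "[Partially Supported]"
    else:
        support = "[No Support]"
    return {
        "retrieve": "[Retrieve]" if passage else "[No Retrieve]",
        "isrel": "[Relevant]" if rel >= 0.3 else "[Irrelevant]",
        "issup": support,
        "isuse": "[Utility:5]" if sup >= 0.6 and rel >= 0.3 else "[Utility:3]",
    }


labels = label_example(
    "What is negligence?",
    "Negligence is a failure to exercise reasonable care.",
    "Negligence is a failure to exercise reasonable care.",
)
print(labels)
```

A labeler that emits four labels per example would be consistent with the critic log in Step 2, where 10 labeled examples become 40 training examples, but that is an inference from the counts rather than something the logs state.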


## Step 2: Train Critic Model

Train the critic model to predict reflection tokens.

In [2]:
%%bash
# Train the critic (lower num_train_epochs in the config for a quick test run)
uv run python -m src.training.train_critic_qlora \
    --config ../configs/critic_config.yaml && \
echo "✅ Critic model trained!"

CRITIC MODEL TRAINING

Configuration loaded from: ../configs/critic_config.yaml
Resolved training_data_dir: /Users/marcuschang/Library/CloudStorage/OneDrive-Personal/桌面/UCSD/DSC261_Responsible_DS/data/training
Resolved output_dir: /Users/marcuschang/Library/CloudStorage/OneDrive-Personal/桌面/UCSD/DSC261_Responsible_DS/models/critic_lora

Device Selection: MPS
✓ Mac GPU (MPS) available and will be used
  PyTorch MPS backend: True


1. Loading tokenizer...
   Added 18 reflection tokens to vocabulary

2. Loading base model...
   Note: 4-bit quantization disabled for macOS compatibility

3. Preparing model for LoRA training...
trainable params: 2,162,688 || all params: 495,968,768 || trainable%: 0.4361

4. Loading and formatting training data...
Loading training data from /Users/marcuschang/Library/CloudStorage/OneDrive-Personal/桌面/UCSD/DSC261_Responsible_DS/data/training/labeled_data.json
Loaded 10 examples


Formatting examples: 100%|██████████| 10/10 [00:00<00:00, 10462.22it/s]


   Created 40 training examples
   Train: 36, Validation: 4

5. Tokenizing datasets...


Tokenizing train: 100%|██████████| 36/36 [00:00<00:00, 378.87 examples/s]
Tokenizing validation: 100%|██████████| 4/4 [00:00<00:00, 2126.39 examples/s]



6. Setting up training...

7. Starting training...


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
                                             

{'train_runtime': 103.2187, 'train_samples_per_second': 1.046, 'train_steps_per_second': 0.087, 'train_loss': 2.247092776828342, 'epoch': 3.0}


100%|██████████| 9/9 [01:43<00:00, 11.47s/it]



8. Saving final model...





TRAINING COMPLETE!
Model saved to: /Users/marcuschang/Library/CloudStorage/OneDrive-Personal/桌面/UCSD/DSC261_Responsible_DS/models/critic_lora/final
✅ Critic model trained!
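
The throughput figures in the final log line can be cross-checked against the dataset size. The sketch below assumes that `train_samples_per_second` is computed as training examples times epochs divided by runtime, which matches the logged numbers; `throughput` is a hypothetical helper.

```python
def throughput(num_train, epochs, total_steps, runtime_s):
    """Recompute the summary metrics from the counts shown in the log."""
    return {
        "samples_per_second": round(num_train * epochs / runtime_s, 3),
        "steps_per_second": round(total_steps / runtime_s, 3),
        "steps_per_epoch": total_steps / epochs,
    }


# Critic run: 36 training examples, 3 epochs, 9 optimizer steps, 103.2187 s.
critic = throughput(num_train=36, epochs=3, total_steps=9, runtime_s=103.2187)
print(critic)
# {'samples_per_second': 1.046, 'steps_per_second': 0.087, 'steps_per_epoch': 3.0}
```

Three optimizer steps per epoch over 36 examples suggests roughly 12 examples per step. The actual batch size and gradient accumulation settings live in `critic_config.yaml`, which is not shown here.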


## Step 3: Train Generator Model

Train the generator model on data augmented with the critic model's predictions.

In [3]:
%%bash
# Train generator with critic weights
uv run python -m src.training.train_generator_qlora \
    --config ../configs/generator_config.yaml \
    --critic-weights ../models/critic_lora/final && \
echo "✅ Generator model trained!"

GENERATOR MODEL TRAINING

Configuration loaded from: ../configs/generator_config.yaml
Resolved training_data_dir: /Users/marcuschang/Library/CloudStorage/OneDrive-Personal/桌面/UCSD/DSC261_Responsible_DS/data/training
Resolved output_dir: /Users/marcuschang/Library/CloudStorage/OneDrive-Personal/桌面/UCSD/DSC261_Responsible_DS/models/generator_lora

Device Selection: MPS
✓ Mac GPU (MPS) available and will be used
  PyTorch MPS backend: True


1. Loading tokenizer...
   Added 18 reflection tokens to vocabulary

2. Loading critic model from ../models/critic_lora/final...
Loading model: Qwen/Qwen2.5-0.5B-Instruct
Loading LoRA weights from ../models/critic_lora/final
Model loaded successfully
   Critic model loaded

3. Loading training data...
Loading training data from /Users/marcuschang/Library/CloudStorage/OneDrive-Personal/桌面/UCSD/DSC261_Responsible_DS/data/training/labeled_data.json
Loaded 10 examples
Augmenting data with critic model predictions...


  0%|          | 0/10 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
100%|██████████| 10/10 [00:14<00:00,  1.40s/it]
Formatting examples: 100%|██████████| 10/10 [00:00<00:00, 8005.92it/s]


   Created 10 training examples
   Train: 9, Validation: 1

4. Loading base model...
   Note: 4-bit quantization disabled for macOS compatibility

5. Preparing model for LoRA training...
trainable params: 8,798,208 || all params: 502,604,288 || trainable%: 1.7505

6. Tokenizing datasets...


Tokenizing train: 100%|██████████| 9/9 [00:00<00:00, 1164.15 examples/s]
Tokenizing validation: 100%|██████████| 1/1 [00:00<00:00, 541.34 examples/s]



7. Setting up training...

8. Starting training...


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
                                             

{'train_runtime': 75.8859, 'train_samples_per_second': 0.356, 'train_steps_per_second': 0.04, 'train_loss': 2.025380770365397, 'epoch': 3.0}


100%|██████████| 3/3 [01:15<00:00, 25.29s/it]



9. Saving final model...





TRAINING COMPLETE!
Model saved to: /Users/marcuschang/Library/CloudStorage/OneDrive-Personal/桌面/UCSD/DSC261_Responsible_DS/models/generator_lora/final
✅ Generator model trained!
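
The generator log reports about four times as many trainable parameters as the critic log (8,798,208 versus 2,162,688). LoRA adds `rank * (d_in + d_out)` weights per adapted linear layer, so the counts can be reproduced by hand. The sketch below is a guess at one configuration that matches both totals. The layer shapes are assumed values for Qwen2.5-0.5B, and the rank and target modules are assumptions, since neither config file is shown in this notebook.

```python
def lora_params(rank, shapes, num_layers):
    """LoRA adds rank * (d_in + d_out) weights per adapted linear layer."""
    return num_layers * sum(rank * (d_in + d_out) for d_in, d_out in shapes)


# Assumed Qwen2.5-0.5B dimensions (not taken from this notebook).
HIDDEN, KV, MLP, LAYERS = 896, 128, 4864, 24
ATTN = [(HIDDEN, HIDDEN), (HIDDEN, KV), (HIDDEN, KV), (HIDDEN, HIDDEN)]  # q, k, v, o
MLP_PROJ = [(HIDDEN, MLP), (HIDDEN, MLP), (MLP, HIDDEN)]                 # gate, up, down

# Hypothetical configs: rank 16 on attention only vs. attention plus MLP.
critic_guess = lora_params(16, ATTN, LAYERS)
generator_guess = lora_params(16, ATTN + MLP_PROJ, LAYERS)
print(critic_guess, generator_guess)  # 2162688 8798208


def trainable_percent(trainable, total):
    """Percentage reported as `trainable%` in the logs."""
    return round(100 * trainable / total, 4)


print(trainable_percent(2_162_688, 495_968_768))  # 0.4361
print(trainable_percent(8_798_208, 502_604_288))  # 1.7505
```

Other combinations give the same totals (rank 32 on only the query and value projections also yields 2,162,688), so treat this as a consistency check rather than a readout of the configs. Subtracting trainable from total parameters gives 493,806,080 in both runs, as expected when both adapters sit on the same base model.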


## Step 4: Test Trained Models

Quick test of the trained Self-RAG system.

In [4]:
import sys
sys.path.append('..')

from src.self_rag.inference import load_pipeline_from_config

# Load complete pipeline
pipeline = load_pipeline_from_config(
    retrieval_config_path='../configs/retrieval_config.yaml',
    generator_config_path='../configs/generator_config.yaml',
    retriever_index_dir='../data/embeddings',
    generator_weights_path='../models/generator_lora/final',
    critic_weights_path='../models/critic_lora/final',
)

print("✅ Pipeline loaded!")

Loading Self-RAG Pipeline...

1. Loading retriever...
Loading embedding model: sentence-transformers/all-mpnet-base-v2
Model loaded on mps
Embedding dimension: 768
   Loading index from ../data/embeddings
Using CPU index
Created IndexFlatIP index with dimension 768
Index loaded from ../data/embeddings/faiss_index.faiss
Total documents in index: 10
Documents loaded from ../data/embeddings/documents.pkl
   Index loaded: 10 documents

2. Loading generator...
Loading generator model: Qwen/Qwen2.5-0.5B-Instruct
Loading LoRA weights from ../models/generator_lora/final
Generator model loaded successfully
MPS cache cleared

3. Loading critic model for reflection tokens...
Loading model: Qwen/Qwen2.5-0.5B-Instruct
Loading LoRA weights from ../models/critic_lora/final
Model loaded successfully
MPS cache cleared
   Critic model loaded successfully

Pipeline loaded successfully!
✅ Pipeline loaded!


In [5]:
# Test question
question = "What are the elements of negligence?"

result = pipeline.answer_question(question)

print(f"Question: {question}\n")
print(f"Answer: {result['answer']}\n")
print(f"Reflection: {result['reflection']}\n")
print(f"Score: {result['score']:.2f}")

The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Question: What are the elements of negligence?

Answer: Negligence is a legal concept that involves a duty of care owed by one person to another. The duty of care requires that the defendant must have taken reasonable steps to prevent harm from occurring, and if they fail to do so, the plaintiff can recover damages for their injuries.
The elements of negligence include:
- Duty of care - A person owes a duty of care to others to avoid causing them harm.
- Breach of duty - The defendant fails to take reasonable steps to prevent harm from occurring.
- Damages - The plaintiff must be able to prove that they suffered actual harm as a result of the defendant's breach of duty.
- Proximate cause - The defendant must have caused the harm in question proximately (directly) through their actions.
- Contributory negligence - If the plaintiff was negligent themselves, they may be held responsible for any resulting damage.
- Vicarious liability - If the defendant is acting within the scope of their 

## Training Tips

### For CPU Training:
- Reduce `per_device_train_batch_size` to 1-2 to limit memory use
- Increase `gradient_accumulation_steps` so the effective batch size stays the same
- Reduce `num_train_epochs` to 1 for a quick test run
- Use a smaller base model if one is available
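
The first two tips work together: the effective batch size is `per_device_train_batch_size * gradient_accumulation_steps` (times the number of devices), so shrinking one means growing the other. A small sketch, with `accumulation_steps` as a hypothetical helper and a target of 8 chosen only for illustration:

```python
import math


def accumulation_steps(target_effective_batch, per_device_batch, num_devices=1):
    """Smallest accumulation count that reaches the target effective batch size."""
    per_step = per_device_batch * num_devices
    return max(1, math.ceil(target_effective_batch / per_step))


for batch in (1, 2, 4, 8):
    steps = accumulation_steps(target_effective_batch=8, per_device_batch=batch)
    print(f"per_device_train_batch_size={batch} -> gradient_accumulation_steps={steps}")
```

When the target is not a multiple of the per-device batch, the helper rounds up, so the effective batch can end up slightly larger than requested.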

### For GPU Training:
- Use larger batch sizes (4-8)
- Enable `fp16` or `bf16` in the config if the hardware supports it
- Monitor GPU memory usage

### Monitoring:
- Check `models/*/logs/` for TensorBoard logs
- Confirm that training loss trends downward across logging steps
- Save checkpoints frequently so an interrupted run can be resumed
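
If TensorBoard is not handy, the loss trend can also be read from a checkpoint. The sketch below assumes the training scripts use the Hugging Face `Trainer`, which writes a `trainer_state.json` containing a `log_history` list into each checkpoint directory. The example path is hypothetical, and both helper functions are illustrative.

```python
import json
from pathlib import Path


def loss_history(trainer_state_path):
    """Return (step, loss) pairs from a Trainer `trainer_state.json` file."""
    state = json.loads(Path(trainer_state_path).read_text())
    return [
        (entry["step"], entry["loss"])
        for entry in state.get("log_history", [])
        if "loss" in entry and "step" in entry  # skip eval and summary entries
    ]


def is_decreasing(history, window=3):
    """Compare the mean loss of the first and last `window` logged points."""
    losses = [loss for _, loss in history]
    if len(losses) < 2:
        return False  # not enough points to judge a trend
    w = max(1, min(window, len(losses) // 2))
    return sum(losses[-w:]) / w < sum(losses[:w]) / w


# Hypothetical checkpoint path; substitute one that exists in your output_dir.
state_file = Path("../models/critic_lora/checkpoint-9/trainer_state.json")
if state_file.exists():
    history = loss_history(state_file)
    print(history)
    print("Loss decreasing:", is_decreasing(history))
```

With only a handful of optimizer steps, as in the runs above, few loss values are logged, so the trend is only meaningful on longer runs.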

## Summary

Training complete!
- ✅ Generated training labels
- ✅ Trained critic model
- ✅ Trained generator model
- ✅ Tested Self-RAG pipeline

**Next:** Proceed to `04_evaluation.ipynb` to evaluate performance.