# Training Self-RAG Models

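Train the critic and generator models for the Self-RAG system using QLoRA. On macOS the training scripts disable 4-bit quantization, so the runs recorded in this notebook train LoRA adapters on an unquantized base model.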

## Prerequisites

Before training:
1. ✅ Documents indexed (from notebook 02)
2. ✅ Training data prepared (from notebook 01)
3. ⚠️ Training requires significant compute (GPU recommended)
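
The cell outputs in this notebook repeat a `huggingface/tokenizers` warning about forking after parallelism has been used. The warning itself suggests setting `TOKENIZERS_PARALLELISM` explicitly. A minimal sketch, assuming the `%%bash` cells are launched from this kernel and therefore inherit its environment:

```python
import os

# Set this before running the %%bash cells; child processes started by the
# kernel inherit the kernel's environment variables.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

print(os.environ["TOKENIZERS_PARALLELISM"])
```

Choosing `"false"` trades tokenizer parallelism for quieter logs; with the 10-example dataset used here the difference should be negligible.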

## Step 1: Generate Training Labels

Generate reflection-token labels for the Q&A data using the rule-based labeler.

In [6]:
%%bash
# Generate labels using rule-based approach
uv run python -m src.training.generate_labels \
    --input ../data/samples/sample_qa_data.json \
    --output-dir ../data/training \
    --num-samples 10 && \
echo "✅ Labels generated!"

python(46831) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Generating labels: 100%|██████████| 10/10 [00:00<00:00, 29495.81it/s]



Labeled 10 examples
Saved to ../data/training/labeled_data.json
✅ Labels generated!
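
Before training, it can help to look at what the labeler wrote. The sketch below is a hedged example: it assumes `labeled_data.json` holds a top-level JSON list of objects, which this notebook does not show. The helper name `summarize_labeled_data` is hypothetical and not part of `src.training`.

```python
import json
from collections import Counter
from pathlib import Path


def summarize_labeled_data(path):
    """Return (number of records, Counter of field names) for a labeled-data file."""
    records = json.loads(Path(path).read_text(encoding="utf-8"))
    if not isinstance(records, list):
        # Assumption: the file is a list of examples; fail loudly otherwise.
        raise ValueError("expected a top-level JSON list of examples")
    field_counts = Counter()
    for record in records:
        if isinstance(record, dict):
            field_counts.update(record.keys())
    return len(records), field_counts


labeled_path = Path("../data/training/labeled_data.json")
if labeled_path.exists():
    count, fields = summarize_labeled_data(labeled_path)
    print(f"{count} examples")
    for name, seen in sorted(fields.items()):
        print(f"  {name}: present in {seen} examples")
```

A field that appears in fewer examples than the total is worth checking, since a missing label may be skipped or mishandled during formatting.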


## Step 2: Train Critic Model

Train the critic model to predict reflection tokens.

In [7]:
%%bash
# Train critic (for a quick test, reduce the number of epochs in the config)
uv run python -m src.training.train_critic_qlora \
    --config ../configs/critic_config.yaml && \
echo "✅ Critic model trained!"

python(46835) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


CRITIC MODEL TRAINING

Configuration loaded from: ../configs/critic_config.yaml
Resolved training_data_dir: /Users/marcuschang/Library/CloudStorage/OneDrive-Personal/桌面/UCSD/DSC261_Responsible_DS/data/training
Resolved output_dir: /Users/marcuschang/Library/CloudStorage/OneDrive-Personal/桌面/UCSD/DSC261_Responsible_DS/models/critic_lora

Device Selection: MPS
✓ Mac GPU (MPS) available and will be used
  PyTorch MPS backend: True


1. Loading tokenizer...
   Added 18 reflection tokens to vocabulary

2. Loading base model...
   Note: 4-bit quantization disabled for macOS compatibility

3. Preparing model for LoRA training...
trainable params: 4,358,144 || all params: 1,547,683,840 || trainable%: 0.2816

4. Loading and formatting training data...
Loading training data from /Users/marcuschang/Library/CloudStorage/OneDrive-Personal/桌面/UCSD/DSC261_Responsible_DS/data/training/labeled_data.json
Loaded 10 examples


Formatting examples: 100%|██████████| 10/10 [00:00<00:00, 3107.58it/s]


   Created 40 training examples
   Train: 36, Validation: 4

5. Tokenizing datasets...


Tokenizing train: 100%|██████████| 36/36 [00:00<00:00, 2436.86 examples/s]
Tokenizing validation: 100%|██████████| 4/4 [00:00<00:00, 1769.19 examples/s]



6. Setting up training...

7. Starting training...


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
100%|██████████| 9/9 [07:17<00:00, 39.90s/it]

{'train_runtime': 437.7311, 'train_samples_per_second': 0.247, 'train_steps_per_second': 0.021, 'train_loss': 2.1289632585313587, 'epoch': 3.0}


100%|██████████| 9/9 [07:17<00:00, 48.64s/it]



8. Saving final model...





TRAINING COMPLETE!
Model saved to: /Users/marcuschang/Library/CloudStorage/OneDrive-Personal/桌面/UCSD/DSC261_Responsible_DS/models/critic_lora/final
✅ Critic model trained!
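
The summary numbers in the critic log can be cross-checked against each other. The sketch below copies the final metrics line, the `Train: 36` count, the 9-step progress bar, and the parameter counts from the output above, then recomputes the reported throughput and trainable percentage. Nothing here calls the training code; it is only arithmetic on logged values.

```python
import ast

# Final metrics line copied from the critic training output above.
metrics_line = (
    "{'train_runtime': 437.7311, 'train_samples_per_second': 0.247, "
    "'train_steps_per_second': 0.021, 'train_loss': 2.1289632585313587, "
    "'epoch': 3.0}"
)
metrics = ast.literal_eval(metrics_line)  # safe parsing of a dict literal

train_examples = 36       # "Train: 36, Validation: 4"
optimizer_steps = 9       # total shown by the progress bar
trainable_params = 4_358_144
all_params = 1_547_683_840

samples_per_second = train_examples * metrics["epoch"] / metrics["train_runtime"]
steps_per_second = optimizer_steps / metrics["train_runtime"]
trainable_percent = 100 * trainable_params / all_params

print(round(samples_per_second, 3))  # 0.247, matching the log
print(round(steps_per_second, 3))    # 0.021, matching the log
print(round(trainable_percent, 4))   # 0.2816, matching the log
```

At roughly 40 to 49 seconds per optimizer step on MPS, even this 36-example run takes about seven minutes, which is a useful baseline when estimating larger runs.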


## Step 3: Train Generator Model

Train the generator model on training data augmented with the critic model's reflection-token predictions.

In [8]:
%%bash
# Train generator with critic weights
uv run python -m src.training.train_generator_qlora \
    --config ../configs/generator_config.yaml \
    --critic-weights ../models/critic_lora/final && \
echo "✅ Generator model trained!"

python(47727) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


GENERATOR MODEL TRAINING

Configuration loaded from: ../configs/generator_config.yaml
Resolved training_data_dir: /Users/marcuschang/Library/CloudStorage/OneDrive-Personal/桌面/UCSD/DSC261_Responsible_DS/data/training
Resolved output_dir: /Users/marcuschang/Library/CloudStorage/OneDrive-Personal/桌面/UCSD/DSC261_Responsible_DS/models/generator_lora

Device Selection: MPS
✓ Mac GPU (MPS) available and will be used
  PyTorch MPS backend: True


1. Loading tokenizer...
   Added 18 reflection tokens to vocabulary

2. Loading critic model from ../models/critic_lora/final...
Loading model: Qwen/Qwen2.5-1.5B-Instruct
Loading LoRA weights from ../models/critic_lora/final
Model loaded successfully
   Critic model loaded

3. Loading training data...
Loading training data from /Users/marcuschang/Library/CloudStorage/OneDrive-Personal/桌面/UCSD/DSC261_Responsible_DS/data/training/labeled_data.json
Loaded 10 examples
Augmenting data with critic model predictions...


  0%|          | 0/10 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
100%|██████████| 10/10 [00:36<00:00,  3.68s/it]
Formatting examples: 100%|██████████| 10/10 [00:00<00:00, 8633.81it/s]


   Created 10 training examples
   Train: 9, Validation: 1

4. Loading base model...
   Note: 4-bit quantization disabled for macOS compatibility

5. Preparing model for LoRA training...
trainable params: 18,464,768 || all params: 1,561,790,464 || trainable%: 1.1823

6. Tokenizing datasets...


Tokenizing train: 100%|██████████| 9/9 [00:00<00:00, 325.65 examples/s]
Tokenizing validation: 100%|██████████| 1/1 [00:00<00:00, 651.80 examples/s]



7. Setting up training...

8. Starting training...


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
100%|██████████| 3/3 [04:18<00:00, 86.22s/it]

{'train_runtime': 258.7341, 'train_samples_per_second': 0.104, 'train_steps_per_second': 0.012, 'train_loss': 1.7339324951171875, 'epoch': 3.0}


100%|██████████| 3/3 [04:18<00:00, 86.24s/it]



9. Saving final model...





TRAINING COMPLETE!
Model saved to: /Users/marcuschang/Library/CloudStorage/OneDrive-Personal/桌面/UCSD/DSC261_Responsible_DS/models/generator_lora/final
✅ Generator model trained!
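
The generator adapter has about four times as many trainable parameters as the critic adapter (18,464,768 versus 4,358,144). A LoRA adapter on a linear layer with `d_in` inputs and `d_out` outputs adds `r * (d_in + d_out)` parameters at rank `r`. The sketch below applies that formula under loudly stated assumptions: the layer shapes are those published for Qwen2.5-1.5B (not shown in this notebook), and the ranks and target modules are guesses that happen to reproduce the logged counts. The config files may use a different combination that yields the same totals.

```python
# Assumed Qwen2.5-1.5B shapes: 28 layers, hidden size 1536,
# key/value width 256, MLP width 8960. Not taken from this notebook.
NUM_LAYERS = 28
SHAPES = {  # module name -> (in_features, out_features)
    "q_proj": (1536, 1536),
    "k_proj": (1536, 256),
    "v_proj": (1536, 256),
    "o_proj": (1536, 1536),
    "gate_proj": (1536, 8960),
    "up_proj": (1536, 8960),
    "down_proj": (8960, 1536),
}


def lora_param_count(rank, target_modules):
    """Parameters added by LoRA: rank * (d_in + d_out) per adapted layer."""
    per_layer = sum(rank * sum(SHAPES[name]) for name in target_modules)
    return NUM_LAYERS * per_layer


# Hypothetical configurations, chosen only because they match the logs.
critic_guess = lora_param_count(32, ["q_proj", "v_proj"])
generator_guess = lora_param_count(16, list(SHAPES))
print(critic_guess, generator_guess)

# Both runs should share the same frozen base: all params minus trainable.
critic_base = 1_547_683_840 - 4_358_144
generator_base = 1_561_790_464 - 18_464_768
print(critic_base == generator_base)
```

The matching base counts confirm that both adapters sit on the same underlying model with the same resized vocabulary; only the adapter size differs.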


## Step 4: Test Trained Models

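Run a quick end-to-end test of the trained Self-RAG system. In the run recorded below, the critic was not loaded ("Continuing without critic"), so every reflection field in the result is `None`.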

In [1]:
import sys
sys.path.append('..')

from src.self_rag.inference import load_pipeline_from_config

# Load complete pipeline
pipeline = load_pipeline_from_config(
    retrieval_config_path='../configs/retrieval_config.yaml',
    generator_config_path='../configs/generator_config.yaml',
    retriever_index_dir='../data/embeddings',
    generator_weights_path='../models/generator_lora/final',
    critic_weights_path='../models/critic_lora/final',
)

print("✅ Pipeline loaded!")

Loading Self-RAG Pipeline...

1. Loading retriever...
Loading embedding model: sentence-transformers/all-mpnet-base-v2
Model loaded on mps
Embedding dimension: 768
   Loading index from ../data/embeddings
Using CPU index
Created IndexFlatIP index with dimension 768
Index loaded from ../data/embeddings/faiss_index.faiss
Total documents in index: 10
Documents loaded from ../data/embeddings/documents.pkl
   Index loaded: 10 documents

2. Loading generator...
Loading generator model: Qwen/Qwen2.5-1.5B-Instruct
Loading LoRA weights from ../models/generator_lora/final
Generator model loaded successfully

3. Loading critic model for reflection tokens...
   Continuing without critic - reflection tokens may be unavailable

Pipeline loaded successfully!
✅ Pipeline loaded!


In [2]:
# Test question
question = "What are the elements of negligence?"

result = pipeline.answer_question(question)

print(f"Question: {question}\n")
print(f"Answer: {result['answer']}\n")
print(f"Reflection: {result['reflection']}\n")
print(f"Score: {result['score']:.2f}")

Question: What are the elements of negligence?

Answer: The four elements of negligence are:
1. A duty to act.
2. A breach of that duty.
3. An injury or damage caused by the breach.
4. Proximate cause, which means that the injury was reasonably foreseeable from the defendant's conduct and proximately resulted from it.
Negligence is a legal concept that allows people who have not acted intentionally to be held responsible for harm they cause due to their failure to exercise reasonable care. In other words, if someone fails to take reasonable precautions to prevent harm, and as a result causes an injury, that person may be liable in tort (which means "wrongful") for that injury.

The first element of negligence is a duty to act. This requires that there be some

Reflection: {'retrieve': None, 'isrel': None, 'issup': None, 'isuse': None, 'intent': None}

Score: 1.00
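
Every reflection field in the output above is `None`, which is consistent with the "Continuing without critic" message printed while the pipeline loaded. A small check like the following makes that condition explicit instead of leaving it to visual inspection. The helper name `missing_reflection_fields` is hypothetical; it relies only on the dictionary shape shown in the output.

```python
def missing_reflection_fields(reflection):
    """Return the names of reflection fields whose value is None."""
    return [name for name, value in reflection.items() if value is None]


# Reflection dict copied from the output above.
reflection = {'retrieve': None, 'isrel': None, 'issup': None, 'isuse': None, 'intent': None}

missing = missing_reflection_fields(reflection)
if len(missing) == len(reflection):
    print("No reflection tokens were produced; check that the critic loaded.")
elif missing:
    print(f"Missing reflection fields: {missing}")
```

In a live session, passing `result['reflection']` to the helper gives the same check for any question.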


## Training Tips

### For CPU Training:
- Reduce `per_device_train_batch_size` to 1-2
- Increase `gradient_accumulation_steps` to compensate for the smaller batch size
- Reduce `num_train_epochs` to 1 for testing
- Use smaller models if available
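
The first two tips work together: the effective batch size per optimizer step is `per_device_train_batch_size * gradient_accumulation_steps` (times the number of devices). A minimal sketch of that arithmetic, using a target effective batch size of 16 purely as an illustration, not a value taken from the config files:

```python
def accumulation_steps_for(target_effective_batch, per_device_batch, num_devices=1):
    """Smallest gradient_accumulation_steps that reaches the target effective batch."""
    per_step = per_device_batch * num_devices
    return -(-target_effective_batch // per_step)  # ceiling division


# Illustrative target; pick the value your config currently implies.
for per_device_batch in (1, 2, 4):
    steps = accumulation_steps_for(16, per_device_batch)
    print(f"batch={per_device_batch} accumulation={steps} effective={per_device_batch * steps}")
```

Holding the effective batch size fixed keeps the number of optimizer steps per epoch the same while lowering peak memory per forward pass.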

### For GPU Training:
- Use larger batch sizes (4-8)
- Enable `fp16` or `bf16` in config
- Monitor GPU memory usage

### Monitoring:
- Check `models/*/logs/` for TensorBoard logs
- Confirm that the training loss decreases over the course of the run
- Save checkpoints frequently so an interrupted run can be resumed

## Summary

Training complete!
- ✅ Generated training labels
- ✅ Trained critic model
- ✅ Trained generator model
- ✅ Tested Self-RAG pipeline (in the recorded run the critic was not loaded, so reflection tokens were unavailable)

**Next:** Proceed to `04_evaluation.ipynb` to evaluate performance