# Policy Rule Training Pipeline

This notebook combines dataset generation and model training for fine-tuning Qwen2.5-1.5B-Instruct with LoRA on Rego policy rules.

## Steps:
1. **Setup & Configuration** - Set paths and parameters
2. **Generate Dataset** - Parse Rego files and create training examples
3. **Validate Dataset** - Check dataset quality and statistics
4. **Prepare Training** - Load and tokenize data
5. **Train Model** - Fine-tune with LoRA
6. **Evaluate** - Check training results


## 1. Setup & Configuration


In [None]:
import json
import os
import sys
import re
import subprocess
import tempfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from collections import defaultdict
import random
import torch
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, TaskType

# Set TOKENIZERS_PARALLELISM to "false" so the tokenizers library does not
# emit its parallelism warnings while this notebook runs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Put the current working directory at the front of sys.path so the local
# `generate_dataset` module imported later in this notebook can be resolved.
sys.path.insert(0, str(Path.cwd()))


In [None]:
# Configuration
# When the notebook is launched from inside `qwen2.5_model`, the repository
# root is the parent directory; otherwise the working directory is the root.
_cwd = Path.cwd()
REPO_ROOT = _cwd.parent if _cwd.name == "qwen2.5_model" else _cwd

# Policy sources
_policy_dir = REPO_ROOT / "policy"
POLICY_RELEASE_DIR = _policy_dir / "release"
POLICY_LIB_DIR = _policy_dir / "lib"
RELEASE_LIB_DIR = POLICY_RELEASE_DIR / "lib"

# Dataset paths (all written next to the notebook in `qwen2.5_model`)
_model_dir = REPO_ROOT / "qwen2.5_model"
TRAIN_PATH = _model_dir / "train.jsonl"
EVAL_PATH = _model_dir / "eval.jsonl"
DATASET_SUMMARY_PATH = _model_dir / "dataset_summary.json"

# Training configuration
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
OUTPUT_DIR = REPO_ROOT / "qwen2.5-rego-policy-lora"
MAX_SEQ_LEN = 1024
BATCH_SIZE, GRAD_ACCUM_STEPS = 2, 4
LEARNING_RATE = 5e-5
NUM_EPOCHS = 3
LORA_R, LORA_ALPHA, LORA_DROPOUT = 16, 32, 0.05

# Dataset generation settings
TRAIN_SPLIT = 0.9  # 90% train, 10% eval
MAX_TOKENS = 1024

# Echo the resolved locations so a wrong working directory is obvious early.
print(
    f"Repository root: {REPO_ROOT}",
    f"Policy release dir: {POLICY_RELEASE_DIR}",
    f"Output directory: {OUTPUT_DIR}",
    sep="\n",
)


## 2. Generate Dataset

Import the dataset generation functions and run them.


In [None]:
# Import dataset generation functions
# Note: You may need to adjust the import path based on your setup
try:
    from generate_dataset import (
        parse_rego_files,
        generate_training_examples,
        validate_rego_code,
        split_train_eval,
        write_jsonl,
        RuleExample,
        RegoFile
    )
    print("✓ Dataset generation functions imported")
except ImportError as exc:
    print("⚠ Could not import from generate_dataset.py")
    # ImportError is also raised when generate_dataset.py is found but one of
    # its own imports or one of the names listed above is missing, so show the
    # actual cause instead of only suggesting a working-directory problem.
    print(f"  Reason: {exc}")
    print("  Make sure you're running from the qwen2.5_model directory")
    print("  Or adjust the import path above")


In [None]:
# Parse all Rego files
print("Parsing Rego files...")
rego_files = parse_rego_files(POLICY_RELEASE_DIR)
print(f"✓ Parsed {len(rego_files)} Rego files")

# Show some statistics: tally the rules across every parsed file
total_rules = 0
for parsed_file in rego_files:
    total_rules += len(parsed_file.rules)
print(f"  Total rules found: {total_rules}")

# Show packages: distinct package names, plus a small sample of them
packages = {parsed_file.package for parsed_file in rego_files}
print(f"  Packages: {len(packages)}")
print(f"  Sample packages: {list(packages)[:5]}")


In [None]:
# Generate training examples
print("Generating training examples...")
examples = []

# Progress is reported whenever the running total crosses the next multiple
# of 50. Each file adds a variable number of examples at once, so testing
# `len(examples) % 50 == 0` would only fire when the total happens to land
# exactly on a multiple of 50 (and would also fire while it is still 0).
next_progress_mark = 50

for rego_file in rego_files:
    file_examples = generate_training_examples(rego_file, POLICY_LIB_DIR, RELEASE_LIB_DIR)
    examples.extend(file_examples)
    if len(examples) >= next_progress_mark:
        print(f"  Generated {len(examples)} examples...")
        next_progress_mark = (len(examples) // 50 + 1) * 50

print(f"✓ Generated {len(examples)} total examples")

# Count by task type (over every generated example, before validation)
task_types = defaultdict(int)
for ex in examples:
    task_types[ex.task_type] += 1
print(f"  Task types: {dict(task_types)}")


In [None]:
# Validate examples
print("Validating examples...")
valid_examples = []
invalid_count = 0

for position, example in enumerate(examples):
    if position % 50 == 0:
        print(f"  Validated {position}/{len(examples)}...")

    # Recover the package name and import lines from the example's context,
    # when it has one; otherwise validate with an empty package and no imports.
    package = ""
    imports = []
    context_text = example.context
    if context_text:
        package_match = re.search(r'package\s+(\S+)', context_text)
        if package_match:
            package = package_match.group(1)
        imports = [
            declaration.strip()
            for declaration in re.findall(r'import\s+([^\n]+)', context_text)
        ]

    # Validate output code
    is_valid, formatted_code, error_msg = validate_rego_code(
        example.output_code,
        package=package,
        imports=imports,
    )

    if not is_valid:
        invalid_count += 1
        if invalid_count <= 5:  # Show first 5 errors
            print(f"    Invalid example: {error_msg[:100]}")
        continue

    # Keep the example, replacing its code with the formatted version
    # returned by the validator.
    example.output_code = formatted_code
    valid_examples.append(example)

print(f"✓ Validated: {len(valid_examples)} valid, {invalid_count} invalid")


In [None]:
# Split into train/eval
train_examples, eval_examples = split_train_eval(valid_examples, TRAIN_SPLIT)
print(f"✓ Split: {len(train_examples)} train, {len(eval_examples)} eval")

# Write to JSONL files
write_jsonl(train_examples, TRAIN_PATH)
write_jsonl(eval_examples, EVAL_PATH)
print(f"✓ Wrote {TRAIN_PATH}")
print(f"✓ Wrote {EVAL_PATH}")

# Create summary
# Tally task types over the validated examples only. The `task_types` tally
# built right after generation also counts examples that later failed
# validation, so its counts would not add up to "total_examples" below.
valid_task_types = defaultdict(int)
for ex in valid_examples:
    valid_task_types[ex.task_type] += 1

summary = {
    "total_examples": len(valid_examples),
    "train_examples": len(train_examples),
    "eval_examples": len(eval_examples),
    "task_types": dict(valid_task_types),
    "invalid_count": invalid_count
}
# Explicit encoding keeps the file independent of the platform locale.
with open(DATASET_SUMMARY_PATH, 'w', encoding="utf-8") as f:
    json.dump(summary, f, indent=2)
print(f"✓ Wrote {DATASET_SUMMARY_PATH}")


## 3. Validate Dataset

Check dataset statistics and sample examples.


In [None]:
# Load and display summary
# Read the summary back from disk so this cell shows what was actually
# written, not just what is still in memory.
summary = json.loads(DATASET_SUMMARY_PATH.read_text())

print("Dataset Summary:")
print(json.dumps(summary, indent=2))


In [None]:
# Sample a few examples
print("\nSample Training Examples:\n")
with open(TRAIN_PATH) as f:
    # Pair each line with a 1-based counter; zip stops after the first 3.
    for number, line in zip(range(1, 4), f):
        example = json.loads(line)
        context_chars = len(example.get('context', ''))
        output_chars = len(example['output_code'])
        print(f"Example {number} ({example['task_type']}):")
        print(f"  Instruction: {example['instruction'][:100]}...")
        print(f"  Context length: {context_chars} chars")
        print(f"  Output code length: {output_chars} chars")
        print()


## 4. Prepare Training

Load tokenizer, create dataset class, and prepare data loaders.


In [None]:
# Load tokenizer
print(f"Loading tokenizer from {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Fall back to the end-of-sequence token for padding when the tokenizer
# does not define a pad token of its own.
pad_token_missing = tokenizer.pad_token is None
if pad_token_missing:
    tokenizer.pad_token = tokenizer.eos_token

vocab_size = len(tokenizer)
print(f"✓ Tokenizer loaded (vocab size: {vocab_size})")


In [None]:
# System prompt
QWEN_SYSTEM_PROMPT = (
    "You are an expert Rego/OPA policy assistant. "
    "You follow instructions carefully and emit valid Rego code using "
    "Conforma's preferred patterns (deny contains result, METADATA, result_helper, etc). "
    "Only use helpers that are provided in the context - never invent new helper functions."
)

def build_messages_from_example(example: Dict) -> List[Dict]:
    """Build chat messages from policy training example.

    The result always starts with the system prompt followed by one user
    message. The user message is assembled from, in order: the example's
    ``context``, its ``instruction`` (prefixed with a newline), and - only
    for ``task_type == "refactor"`` - its ``input_code`` wrapped in a
    ``rego`` code fence. The pieces are joined with newlines. An assistant
    message holding ``output_code`` is appended when that field is present.

    A field whose value is ``None`` (e.g. a JSON ``null`` in the JSONL file)
    is treated exactly like a missing key; previously such a value reached
    the string concatenation / ``join`` and raised ``TypeError``.

    Args:
        example: Mapping with the optional keys ``context``, ``instruction``,
            ``task_type``, ``input_code`` and ``output_code``.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dictionaries.
    """
    messages = [
        {"role": "system", "content": QWEN_SYSTEM_PROMPT}
    ]

    # Build user message
    user_parts = []

    context = example.get("context")
    if context is not None:
        user_parts.append(context)

    instruction = example.get("instruction")
    if instruction is not None:
        user_parts.append("\n" + instruction)

    input_code = example.get("input_code")
    if example.get("task_type") == "refactor" and input_code is not None:
        user_parts.append("\n\nCode to refactor:\n```rego\n" + input_code + "\n```")

    user_content = "\n".join(user_parts)
    messages.append({"role": "user", "content": user_content})

    # Without a target completion only the prompt messages are returned.
    output_code = example.get("output_code")
    if output_code is not None:
        messages.append({"role": "assistant", "content": output_code})

    return messages


In [None]:
# Dataset class
class PolicyDataset(Dataset):
    """Map-style dataset of pre-tokenized, chat-formatted policy examples.

    Every record of a JSONL file is turned into chat messages with
    ``build_messages_from_example``, rendered to text with the tokenizer's
    chat template and tokenized once, up front. Items are dictionaries with
    un-padded ``input_ids`` and ``attention_mask`` lists, truncated to
    ``max_length`` tokens.

    Args:
        jsonl_path: Path of the JSONL file to load (one JSON object per line).
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        max_length: Maximum number of tokens kept per example.
    """

    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Load examples
        self.examples = self._load_examples(jsonl_path)

        # Pre-tokenize all examples
        print(f"Pre-tokenizing {len(self.examples)} examples...")
        self.tokenized = []
        for i, example in enumerate(self.examples):
            if i % 50 == 0:
                print(f"  Tokenized {i}/{len(self.examples)}...")

            messages = build_messages_from_example(example)
            # Render the full conversation (including the assistant turn when
            # present) as text; no generation prompt is appended.
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False
            )

            # Padding is left off here, so sequences keep their own lengths.
            encoded = tokenizer(
                text,
                truncation=True,
                max_length=max_length,
                padding=False
            )

            self.tokenized.append({
                "input_ids": encoded["input_ids"],
                "attention_mask": encoded["attention_mask"]
            })

        print("✓ Pre-tokenization complete")

    @staticmethod
    def _load_examples(jsonl_path):
        """Read a JSONL file into a list of decoded objects.

        Blank or whitespace-only lines are skipped so that stray empty lines
        do not abort loading with a JSON decode error. The file is read as
        UTF-8 regardless of the platform's default encoding.
        """
        with open(jsonl_path, encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        # Measure the list that __getitem__ actually indexes.
        return len(self.tokenized)

    def __getitem__(self, idx):
        return self.tokenized[idx]

print("✓ Dataset class defined")


In [None]:
# Create datasets
# Both splits share the same tokenizer and sequence-length limit.
print("Creating training dataset...")
train_dataset = PolicyDataset(jsonl_path=TRAIN_PATH, tokenizer=tokenizer, max_length=MAX_SEQ_LEN)

print("\nCreating eval dataset...")
eval_dataset = PolicyDataset(jsonl_path=EVAL_PATH, tokenizer=tokenizer, max_length=MAX_SEQ_LEN)

print("\n✓ Datasets ready:")
for split_label, split_dataset in (("Train", train_dataset), ("Eval", eval_dataset)):
    print(f"  {split_label}: {len(split_dataset)} examples")
