# Rule Generation Pipeline

Generate intent rules from uDOM snapshots using Hugging Face causal language models.

In [2]:
import sys
import os
from pathlib import Path
import json
from datetime import datetime
from collections import Counter

sys.path.append(str(Path('../utils/loaders').resolve()))
sys.path.append(str(Path('../utils/validators').resolve()))
sys.path.append(str(Path('../evaluation').resolve()))

from prompt_loader import PromptLoader
from snapshot_loader import SnapshotLoader
from schema_validator import SchemaValidator
from quality_metrics import QualityMetrics

HF_TOKEN = os.environ.get('HF_TOKEN')
print(f"HF_TOKEN: {'loaded' if HF_TOKEN else 'not found'}")

ModuleNotFoundError: No module named 'requests'

In [None]:
# Optional: Set HF_TOKEN directly here if not in environment
# Uncomment and add your token:
# os.environ['HF_TOKEN'] = 'your_huggingface_token_here'

# Re-read the token so an override made just above actually takes effect:
# HF_TOKEN was captured from the environment in an earlier cell and would
# otherwise stay stale (e.g. still None) after uncommenting the line above.
HF_TOKEN = os.environ.get('HF_TOKEN')

# Verify token is accessible
if not HF_TOKEN:
    print("[WARNING] HF_TOKEN not set. You can:")
    print("   1. Set it in this cell (uncomment the os.environ line above)")
    print("   2. Export in terminal: export HF_TOKEN='your_token'")
    print("   3. Create .env file in rl/ or taste/ directory")
else:
    print(f"[OK] HF_TOKEN ready (length: {len(HF_TOKEN)})")


## Configuration

In [None]:
# Read model settings from the training config JSON.
config_path = Path('../config/training_config.json')
config = json.loads(config_path.read_text())
_model_cfg = config['model_config']

MODEL_ID = _model_cfg['base_model']
print(f"Model: {MODEL_ID}")
# Show only the first two alternative base models.
print(f"Alternatives: {_model_cfg['alternatives'][:2]}")

## Load Model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

print(f"Loading model: {MODEL_ID}...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)

# With a GPU: bf16 weights and automatic device placement.
# Without one: fp32 weights and no device map.
# NOTE: trust_remote_code=True lets the model repository run its own Python code locally.
_has_cuda = torch.cuda.is_available()
_model_kwargs = dict(
    token=HF_TOKEN,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if _has_cuda else torch.float32,
    device_map='auto' if _has_cuda else None,
)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_model_kwargs)

# Reuse EOS as the pad token when the tokenizer defines none, so a
# pad_token_id is available for generation.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Model loaded on {model.device}")

## Load Snapshots

In [None]:
snapshot_loader = SnapshotLoader()
stats = snapshot_loader.get_stats()
# Print each stat with its fallback when the key is absent.
for _label, _key, _fallback in (
    ('Storage', 'storage', 'unknown'),
    ('Total snapshots', 'total_snapshots', 0),
):
    print(f"{_label}: {stats.get(_key, _fallback)}")

In [None]:
# Load up to 10 snapshots via load_recent and preview the first three ids.
snapshots = snapshot_loader.load_recent(limit=10)
print(f"Loaded {len(snapshots)} snapshots")
for _pos, _snap in enumerate(snapshots[:3], start=1):
    _short_id = _snap.get('metadata', {}).get('snapshot_id', 'unknown')[:12]
    print(f"  {_pos}. {_short_id}...")

## Load Generator Prompt

In [None]:
prompt_loader = PromptLoader()
generator_prompt = prompt_loader.load_prompt('generator')
# The raw prompt text is used later as the system message.
generator_text = generator_prompt['prompt_text']
_prompt_name, _prompt_version = generator_prompt.get('name'), generator_prompt.get('version')
print(f"Prompt: {_prompt_name} v{_prompt_version}")
print(f"Length: {len(generator_text):,} chars")

## Transform Snapshot

In [None]:
def transform_snapshot(
    snapshot: dict,
    *,
    min_confidence: float = 0.6,
    max_rules_per_batch: int = 5,
    max_property_types: int = 10,
) -> dict:
    """Build the generator input payload from a single uDOM snapshot.

    The same snapshot object is used for both the ``before`` and ``after``
    artifacts, so the two are identical.

    Args:
        snapshot: Snapshot dict. Reads ``elements`` (a list of dicts with
            optional ``type`` and ``properties``) and
            ``observations.provenance``. Missing or null values are treated
            as empty.
        min_confidence: Value copied into ``generation_config``.
        max_rules_per_batch: Value copied into ``generation_config``.
        max_property_types: Maximum number of property names to include,
            taken from the sorted list.

    Returns:
        Dict with ``artifacts``, ``platform_semantics``,
        ``platform_metadata`` and ``generation_config`` keys.
    """
    elements = snapshot.get('elements') or []
    element_types = {e.get('type', 'unknown') for e in elements}
    property_types = set()
    for e in elements:
        # ``properties`` may be present but null; treat it as empty.
        property_types.update((e.get('properties') or {}).keys())

    provenance = (snapshot.get('observations') or {}).get('provenance') or {}
    return {
        # NOTE(review): the full snapshot is embedded twice, which doubles
        # the serialized prompt size.
        'artifacts': {'before': snapshot, 'after': snapshot},
        'platform_semantics': {
            # Sort so the prompt is deterministic, including which properties
            # survive the cap. Set iteration order of str values varies
            # between runs because str hashing is randomized.
            # key=str tolerates None or mixed-type values.
            'element_types': sorted(element_types, key=str),
            'property_types': sorted(property_types, key=str)[:max_property_types],
        },
        'platform_metadata': {
            'platform': provenance.get('tool', 'figma'),
            'extraction_method': provenance.get('extraction_method', 'plugin_api')
        },
        'generation_config': {
            'min_confidence': min_confidence,
            'max_rules_per_batch': max_rules_per_batch,
        }
    }

In [None]:
# Build the generator input from the first loaded snapshot, if there is one.
if snapshots:
    generator_input = transform_snapshot(snapshots[0])
    _meta = generator_input['platform_metadata']
    _semantics = generator_input['platform_semantics']
    print(f"Platform: {_meta['platform']}")
    print(f"Element types: {_semantics['element_types']}")

## Generate Rules Function

In [None]:
def _extract_json_object(text: str):
    """Return the JSON object in *text* that best matches a generator result, or None.

    Tries to decode an object at every '{'. The first decoded object that
    has an ``intent_rules`` key wins. Otherwise the first decodable object
    is returned, or None when nothing decodes.
    """
    decoder = json.JSONDecoder()
    first = None
    pos = text.find('{')
    while pos != -1:
        try:
            obj, end = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            pos = text.find('{', pos + 1)
            continue
        if 'intent_rules' in obj:
            return obj
        if first is None:
            first = obj
        # Resume after the decoded object so its nested braces are not rescanned.
        pos = text.find('{', end)
    return first


def generate_rules(
    generator_input: dict,
    trace_id: str = None,
    *,
    max_input_tokens: int = 2048,
    max_new_tokens: int = 1024,
    temperature: float = 0.7,
) -> dict:
    """Sample intent rules for one generator input with the loaded model.

    Uses ``generator_text`` as the system message and the JSON-serialized
    input as the user message. Samples a completion and parses a JSON
    object out of it.

    Args:
        generator_input: Payload (e.g. from ``transform_snapshot``),
            serialized with ``json.dumps``.
        trace_id: Id stored on the result. Defaults to a timestamp-based id.
        max_input_tokens: Prompt token budget. Longer prompts are truncated
            on the right, as before, but now a warning is printed.
        max_new_tokens: Maximum number of tokens to sample.
        temperature: Sampling temperature.

    Returns:
        The parsed JSON object, or ``{'intent_rules': [], 'raw_response': ...}``
        when no object could be decoded. ``trace_id``, ``generated_at`` and
        ``model`` keys are added in both cases.
    """
    user_content = json.dumps(generator_input, indent=2)
    messages = [
        {'role': 'system', 'content': generator_text},
        {'role': 'user', 'content': f'Generate intent rules:\n\n{user_content}'}
    ]

    prompt = f"System: {generator_text}\n\nUser: Generate intent rules:\n\n{user_content}\n\nAssistant:"
    used_chat_template = False
    # Recent tokenizers always *have* apply_chat_template, but it raises when
    # the checkpoint ships no chat template (common for base models) or when
    # the template rejects a system role. Fall back to the plain prompt then.
    if hasattr(tokenizer, 'apply_chat_template'):
        try:
            prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            used_chat_template = True
        except Exception as exc:  # ValueError, jinja2 TemplateError, ...
            print(f"[WARNING] chat template unavailable ({exc}); using plain prompt")

    # A rendered chat template already contains the special tokens (e.g. BOS),
    # so adding them again while tokenizing would duplicate them.
    add_special = not used_chat_template
    inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=add_special)
    prompt_tokens = inputs['input_ids'].shape[1]
    if prompt_tokens > max_input_tokens:
        # Right-side truncation drops the end of the prompt, including the
        # assistant-turn marker, so output quality likely degrades.
        print(f"[WARNING] prompt has {prompt_tokens} tokens; truncating to {max_input_tokens}")
        inputs = tokenizer(
            prompt,
            return_tensors='pt',
            add_special_tokens=add_special,
            truncation=True,
            max_length=max_input_tokens,
        )
    if torch.cuda.is_available():
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

    result = _extract_json_object(response)
    if result is None:
        result = {'intent_rules': [], 'raw_response': response}

    now = datetime.now()
    result['trace_id'] = trace_id or f"trace_{now.strftime('%Y%m%d_%H%M%S')}"
    result['generated_at'] = now.isoformat()
    result['model'] = MODEL_ID
    return result

## Test Generation

In [None]:
# Run a single generation on the prepared input as a smoke test.
if not snapshots:
    rules_output = None
    print("No snapshots available")
else:
    print("Generating rules...")
    rules_output = generate_rules(generator_input, 'test_trace')
    _generated = rules_output.get('intent_rules', [])
    print(f"Generated {len(_generated)} rules")
    print(f"Model: {rules_output.get('model')}")
    for _n, _rule in enumerate(_generated[:3], start=1):
        print(f"  {_n}. {_rule.get('description', 'N/A')[:60]}...")

## Batch Processing

In [None]:
def process_batch(snapshots: list, batch_size: int = 3, output_path: Path = None) -> list:
    """Generate rules for up to ``batch_size`` snapshots and append each result to a JSONL file.

    Each result is appended as one JSON line to ``output_path``, which
    defaults to ``../data/generated_rules/rules.jsonl``. The file is opened
    in append mode, so re-running adds records rather than replacing them.
    An exception on one snapshot is printed and the batch continues.

    Args:
        snapshots: Snapshot dicts; only the first ``batch_size`` are processed.
        batch_size: Maximum number of snapshots to process.
        output_path: Destination JSONL file, as a ``str`` or ``Path``. Parent
            directories are created if needed.

    Returns:
        Result dicts from ``generate_rules`` in processing order. Snapshots
        whose transform or generation raised are omitted.
    """
    if output_path is None:
        output_path = Path('../data/generated_rules/rules.jsonl')
    output_path = Path(output_path)  # also accept plain string paths
    output_path.parent.mkdir(parents=True, exist_ok=True)

    selected = snapshots[:batch_size]
    all_outputs = []
    total_rules = 0

    for i, snapshot in enumerate(selected):
        # `or` also covers a snapshot_id that is present but null, and str()
        # keeps the [:12] slice below from failing on non-string ids.
        sid = str((snapshot.get('metadata') or {}).get('snapshot_id') or f'snapshot_{i}')
        # The denominator is the number actually selected, which may be < batch_size.
        print(f"[{i+1}/{len(selected)}] Processing {sid[:12]}...")
        try:
            gen_input = transform_snapshot(snapshot)
            result = generate_rules(gen_input, sid)
            num_rules = len(result.get('intent_rules', []))
            all_outputs.append(result)
            total_rules += num_rules

            with open(output_path, 'a', encoding='utf-8') as f:
                f.write(json.dumps(result) + '\n')
            print(f"  Generated {num_rules} rules")
        except Exception as e:
            # Best-effort batch: report the failure and move to the next snapshot.
            print(f"  Error: {e}")

    print(f"\nProcessed {len(all_outputs)} snapshots")
    print(f"Total rules generated: {total_rules}")
    return all_outputs

# Uncomment to run batch processing:
# batch_outputs = process_batch(snapshots, batch_size=3)

## Validate Rules

In [None]:
schema_validator = SchemaValidator()
quality_metrics = QualityMetrics()

# Validate and score the rules from the test generation, if any were produced.
_candidate_rules = rules_output.get('intent_rules') if rules_output else None
if not _candidate_rules:
    print("No rules to validate")
else:
    rules = _candidate_rules
    # validate_rule returns a sequence whose first item is the pass/fail flag.
    valid_count = sum(bool(schema_validator.validate_rule(r)[0]) for r in rules)
    print(f"Validation: {valid_count}/{len(rules)} valid")

    metrics = quality_metrics.compute_rule_quality(rules)
    print(f"Avg confidence: {metrics.avg_confidence:.2f}")
    print(f"Quality score: {metrics.quality_score:.2f}")

## Load Existing Rules

In [None]:
# Collect every intent rule stored in the JSONL file (process_batch's default output path).
rules_path = Path('../data/generated_rules/rules.jsonl')
# Always define all_rules so later cells work even when no file exists yet.
all_rules = []
if rules_path.exists():
    with open(rules_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                # Skip a corrupt line (e.g. a truncated record) instead of
                # aborting the whole load.
                print(f"  Skipping malformed line {line_no}: {e}")
                continue
            # Ignore non-object lines and null/missing intent_rules.
            if isinstance(data, dict):
                all_rules.extend(data.get('intent_rules') or [])
    print(f"Loaded {len(all_rules)} rules from file")
else:
    print(f"No rules file at {rules_path}")

## Next Steps

- 04_dataset_building.ipynb: Generate preference pairs
- 05_dataset_analysis.ipynb: Analyze dataset