# Genie Enhancement v3 - Debug Notebook

## 4-Stage Batch Apply Flow

This notebook tests the v3 enhancement workflow:

1. **Score** - Evaluate benchmarks on Genie Space
2. **Plan** - Analyze failures, generate ALL fixes
3. **Apply** - Apply ALL fixes in ONE batch update
4. **Validate** - Re-score and check improvement

## Key Difference from v2
- v2: Apply fixes one-at-a-time with rollback
- v3: Apply ALL fixes at once (batch)

## Usage
Run cells in order. Each section can be debugged independently.

## 1️⃣ Setup

In [None]:
# IMPORTANT: Clear cached modules to ensure latest code is loaded
# Run this cell first if you've updated the lib/ code
import sys


def is_lib_module(module_name, package="lib"):
    """Return True if *module_name* is *package* itself or one of its submodules.

    A bare ``startswith("lib")`` test would also match unrelated modules whose
    names merely begin with those letters (``lib2to3``, ``libcst``, ...), so
    the match is anchored on the package boundary.
    """
    return module_name == package or module_name.startswith(package + ".")


# Collect the names first so sys.modules is not modified while being iterated.
modules_to_remove = [m for m in sys.modules if is_lib_module(m)]
for m in modules_to_remove:
    del sys.modules[m]

print(f"Cleared {len(modules_to_remove)} cached lib modules")
print("Now run the rest of the notebook to use fresh imports")

In [None]:
# Project path setup
import sys
import os
from pathlib import Path

PROJECT_DIR_NAME = 'genie_enhancer'


def find_project_root(start, dir_name=PROJECT_DIR_NAME):
    """Locate the project directory at or above *start*.

    Args:
        start: Path to begin the search from (normally the working directory).
        dir_name: Directory name that identifies the project root.

    Returns:
        *start* itself or its nearest ancestor whose name is *dir_name*,
        or None when no directory on the path has that name.
    """
    for candidate in (start, *start.parents):
        if candidate.name == dir_name:
            return candidate
    return None


# Find project root
current_path = Path(os.getcwd())
project_root = find_project_root(current_path)
if project_root is None:
    # An unchecked upward walk would stop at the filesystem root and put that
    # on sys.path; fall back to the working directory and say so instead.
    project_root = current_path
    print(f"⚠️ No '{PROJECT_DIR_NAME}' directory found at or above {current_path}; "
          "using the current directory as project root")

if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Configure logging for verbose output
import logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(name)s | %(levelname)s | %(message)s',
    datefmt='%H:%M:%S'
)

# Set verbose logging for lib modules
for module in ['lib.genie_client', 'lib.scorer', 'lib.llm', 'lib.enhancer', 'lib.applier', 'lib.space_api']:
    logging.getLogger(module).setLevel(logging.DEBUG)

print(f"Project root: {project_root}")
print("Logging level: DEBUG (verbose mode enabled)")

In [None]:
# Test imports from lib/
import importlib

# (module, attribute) pairs this notebook relies on.
REQUIRED_IMPORTS = [
    ("lib.genie_client", "GenieConversationalClient"),
    ("lib.space_api", "SpaceUpdater"),
    ("lib.scorer", "BenchmarkScorer"),
    ("lib.benchmark_parser", "BenchmarkLoader"),
    ("lib.llm", "DatabricksLLMClient"),
    ("lib.sql", "SQLExecutor"),
    ("lib.enhancer", "EnhancementPlanner"),
    ("lib.applier", "BatchApplier"),
]


def check_import(module_name, attr_name):
    """Try to import *attr_name* from *module_name*.

    Args:
        module_name: Dotted module path to import.
        attr_name: Name expected to exist on that module.

    Returns:
        ``(obj, None)`` on success, or ``(None, exc)`` where *exc* is the
        exception raised while importing the module or reading the attribute.
    """
    try:
        module = importlib.import_module(module_name)
        return getattr(module, attr_name), None
    except Exception as e:  # smoke test: report whatever went wrong and move on
        return None, e


for module_name, attr_name in REQUIRED_IMPORTS:
    imported, error = check_import(module_name, attr_name)
    if error is None:
        # Bind the name in the notebook namespace, as `from module import name` does.
        globals()[attr_name] = imported
        print(f"✅ {module_name}")
    else:
        print(f"❌ {module_name}: {error}")

In [None]:
# Full imports
import json
import time
from datetime import datetime

from lib.genie_client import GenieConversationalClient
from lib.space_api import SpaceUpdater
from lib.scorer import BenchmarkScorer
from lib.benchmark_parser import BenchmarkLoader
from lib.llm import DatabricksLLMClient
from lib.sql import SQLExecutor
from lib.enhancer import EnhancementPlanner
from lib.applier import BatchApplier

print("✅ All imports successful")

## 2️⃣ Configuration

In [None]:
# === UPDATE THESE VALUES ===
import os

DATABRICKS_HOST = "your-workspace.cloud.databricks.com"
DATABRICKS_TOKEN = "YOUR_TOKEN_HERE"
GENIE_SPACE_ID = "your-space-id"
WAREHOUSE_ID = "your-warehouse-id"  # For metric views
LLM_ENDPOINT = "databricks-claude-sonnet-4"

# Target score
TARGET_SCORE = 0.90

# Keep secrets out of the saved notebook: while the token above is still the
# placeholder, pick it up from the DATABRICKS_TOKEN environment variable when
# that is set. A value edited in above always takes precedence.
if DATABRICKS_TOKEN == "YOUR_TOKEN_HERE":
    DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN", DATABRICKS_TOKEN)
if DATABRICKS_TOKEN == "YOUR_TOKEN_HERE":
    print("⚠️ DATABRICKS_TOKEN is still the placeholder - set it above or via the environment")

# The token is deliberately not echoed.
print(f"Host: {DATABRICKS_HOST}")
print(f"Space ID: {GENIE_SPACE_ID}")
print(f"Warehouse: {WAREHOUSE_ID}")
print(f"LLM: {LLM_ENDPOINT}")
print(f"Target: {TARGET_SCORE:.0%}")

## 3️⃣ Initialize Clients

In [None]:
# LLM client, throttled to protect against endpoint rate limits
print("Initializing LLM Client...")
llm_client = DatabricksLLMClient(
    endpoint_name=LLM_ENDPOINT,
    host=DATABRICKS_HOST,
    token=DATABRICKS_TOKEN,
    rate_limit_base_delay=90.0,  # base delay (s) on rate limit (90, 180, 360s...)
    request_delay=10.0,          # delay (s) between requests
)

# Report the outcome of the connectivity probe; a failure is only printed.
llm_reachable = llm_client.test_connection()
if not llm_reachable:
    print("❌ LLM connection failed")
else:
    print("✅ LLM Client connected")
    print("   - Request delay: 10s between calls")
    print("   - Rate limit retry: 90s base (exponential backoff)")

In [None]:
# LLM Client (fallback)
# The cell above already builds `llm_client` with request throttling. Creating
# a second client here unconditionally would replace it with one that has no
# explicit delay settings, so a new client is only built when none exists yet
# (e.g. when this cell is run on its own).
print("Initializing LLM Client...")
if "llm_client" not in globals():
    llm_client = DatabricksLLMClient(
        host=DATABRICKS_HOST,
        token=DATABRICKS_TOKEN,
        endpoint_name=LLM_ENDPOINT
    )
else:
    print("   (reusing the LLM client created earlier)")

if llm_client.test_connection():
    print("✅ LLM Client connected")
else:
    print("❌ LLM connection failed")

In [None]:
# Space API client: exports the Genie Space config and is handed to the
# batch applier later in the notebook.
print("Initializing Space API...")
space_api = SpaceUpdater(token=DATABRICKS_TOKEN, host=DATABRICKS_HOST)
print("✅ Space API initialized")

In [None]:
# SQL executor bound to the configured warehouse (used for metric views);
# shared by the scorer and the batch applier below.
print("Initializing SQL Executor...")
sql_executor = SQLExecutor(
    warehouse_id=WAREHOUSE_ID,
    host=DATABRICKS_HOST,
    token=DATABRICKS_TOKEN,
)
print("✅ SQL Executor initialized")

In [None]:
# Genie client
# The scorer below (and the Genie debug cell at the end of the notebook) use
# `genie_client`, but no earlier cell creates it, so build it here. An
# existing `genie_client` is reused.
# NOTE(review): the constructor arguments are assumed - host/token as for the
# other lib clients, plus the space id because `genie_client.ask(...)` is
# called without one. Confirm against lib.genie_client.
if "genie_client" not in globals():
    print("Initializing Genie Client...")
    genie_client = GenieConversationalClient(
        host=DATABRICKS_HOST,
        token=DATABRICKS_TOKEN,
        space_id=GENIE_SPACE_ID
    )
    print("✅ Genie Client initialized")

# Benchmark Scorer (verbose config)
print("Initializing Scorer...")
scorer = BenchmarkScorer(
    genie_client=genie_client,
    llm_client=llm_client,
    sql_executor=sql_executor,
    config={
        "question_timeout": 120,
        "question_delay": 3.0,      # Delay between questions
        "error_delay": 5.0,         # Extra delay after errors
        "parallel_workers": 0,      # 0 = sequential (easier to debug)
    }
)
print("✅ Scorer initialized")
print("   - Sequential mode (parallel_workers=0)")
print("   - Question delay: 3s")
print("   - Timeout: 120s")

## 4️⃣ Load Benchmarks

In [None]:
# Read the benchmark suite from <project_root>/benchmarks/benchmarks.json
benchmark_file = project_root / "benchmarks" / "benchmarks.json"
print(f"Loading from: {benchmark_file}")

loader = BenchmarkLoader(str(benchmark_file))
all_benchmarks = loader.load()
print(f"✅ Loaded {len(all_benchmarks)} benchmarks")

# Preview the first three questions, each cut to 60 characters
preview = all_benchmarks[:3]
for number, benchmark in enumerate(preview, start=1):
    question_head = benchmark['question'][:60]
    print(f"  {number}. {question_head}...")

In [None]:
# Optional: restrict the run to a handful of benchmarks for faster testing
USE_SUBSET = True  # Set to False for full run

if not USE_SUBSET:
    benchmarks = all_benchmarks
    print(f"FULL MODE: Using {len(benchmarks)} benchmarks")
else:
    # Test mode keeps only the first five entries
    benchmarks = all_benchmarks[:5]
    print(f"⚠️ TEST MODE: Using {len(benchmarks)} benchmarks")

---
# STAGE 1: SCORE
---

In [None]:
# Run scoring (verbose output)
# Scores the selected benchmarks once and prints the headline numbers. The
# result dict is kept as `score_results` for the planning and validation stages.
print("="*60)
print("STAGE 1: SCORING BENCHMARKS")
print("="*60)
print()

start_time = datetime.now()
score_results = scorer.score(benchmarks)
duration = (datetime.now() - start_time).total_seconds()

# An empty benchmark list would otherwise raise ZeroDivisionError here, after
# the scoring run has already completed.
question_count = len(benchmarks)
avg_per_question = duration / question_count if question_count else 0.0

print()
print("="*60)
print("SCORING COMPLETE")
print("="*60)
print(f"Score: {score_results['score']:.1%}")
print(f"Passed: {score_results['passed']}/{score_results['total']}")
print(f"Failed: {score_results['failed']}")
print(f"Duration: {duration:.1f}s")
print(f"Avg per question: {avg_per_question:.1f}s")

In [None]:
# Per-benchmark breakdown of the Stage 1 run, followed by the list of
# failures (`failed_results`) that the planning stage consumes.
print("="*60)
print("DETAILED RESULTS")
print("="*60)
print()

for position, result in enumerate(score_results['results'], start=1):
    verdict = "❌ FAIL" if not result['passed'] else "✅ PASS"
    print(f"[{position}/{len(score_results['results'])}] {verdict}")
    print(f"   Question: {result['question'][:70]}...")

    if not result['passed']:
        # Failures: category, reason, and whichever SQL statements are present
        print(f"   Category: {result.get('failure_category', 'unknown')}")
        reason_text = result.get('failure_reason') or 'N/A'
        print(f"   Reason: {reason_text[:80]}")
        produced_sql = result.get('genie_sql')
        if produced_sql:
            print(f"   Genie SQL: {produced_sql[:60]}...")
        reference_sql = result.get('expected_sql')
        if reference_sql:
            print(f"   Expected:  {reference_sql[:60]}...")
    else:
        # Passes: only the SQL Genie produced ('N/A' when absent)
        produced_sql = result.get('genie_sql') or 'N/A'
        print(f"   Genie SQL: {produced_sql[:60]}...")

    print(f"   Response time: {result.get('response_time', 0):.1f}s")
    print()

# Summary
failed_results = [entry for entry in score_results['results'] if not entry['passed']]
print("="*60)
print(f"SUMMARY: {len(failed_results)} failures to analyze")
print("="*60)

---
# STAGE 2: PLAN
---

In [None]:
# Export the current Genie Space configuration; it is the baseline the
# planner works from.
print("Exporting current space config...")
space_config = space_api.export_space(GENIE_SPACE_ID)
print("✅ Config loaded")

# Count tables, tolerating a config without data_sources / tables keys
configured_tables = space_config.get('data_sources', {}).get('tables', [])
print(f"   Tables: {len(configured_tables)}")

In [None]:
# Generate enhancement plan (verbose)
# `planner` is used below but no earlier cell creates it, so build it here.
# An existing `planner` is reused.
# NOTE(review): the constructor signature is assumed (an `llm_client` keyword,
# as BenchmarkScorer takes) - confirm against lib.enhancer.EnhancementPlanner.
if "planner" not in globals():
    planner = EnhancementPlanner(llm_client=llm_client)

print("="*60)
print("STAGE 2: GENERATING ENHANCEMENT PLAN")
print("="*60)
print()
print(f"Analyzing {len(failed_results)} failures...")
print("Categories: metric_view, metadata, sample_query, instruction")
print("Parallel workers: 1 (sequential to avoid rate limits)")
print()

plan_start = datetime.now()
grouped_fixes = planner.generate_plan(
    failed_benchmarks=failed_results,
    space_config=space_config,
    parallel_workers=1  # Sequential to avoid rate limits
)
plan_duration = (datetime.now() - plan_start).total_seconds()

# grouped_fixes maps each category to its list of fixes
total_fixes = sum(len(f) for f in grouped_fixes.values())
print()
print("="*60)
print("PLAN GENERATION COMPLETE")
print("="*60)
print(f"Total fixes generated: {total_fixes}")
print(f"Duration: {plan_duration:.1f}s")

In [None]:
# Generate enhancement plan (verbose) - parallel variant
# Alternative to the sequential plan cell above: it calls the planner again
# with two workers and overwrites `grouped_fixes`, `total_fixes` and
# `plan_duration`. Run one of the two cells, not both.
# `planner` is not created by any earlier cell, so build it here when missing.
# NOTE(review): the constructor signature is assumed (an `llm_client` keyword,
# as BenchmarkScorer takes) - confirm against lib.enhancer.EnhancementPlanner.
if "planner" not in globals():
    planner = EnhancementPlanner(llm_client=llm_client)

print("="*60)
print("STAGE 2: GENERATING ENHANCEMENT PLAN")
print("="*60)
print()
print(f"Analyzing {len(failed_results)} failures...")
print("Categories: metric_view, metadata, sample_query, instruction")
print("Parallel workers: 2")
print()

plan_start = datetime.now()
grouped_fixes = planner.generate_plan(
    failed_benchmarks=failed_results,
    space_config=space_config,
    parallel_workers=2  # Two workers; the cell above uses 1 to avoid rate limits
)
plan_duration = (datetime.now() - plan_start).total_seconds()

# grouped_fixes maps each category to its list of fixes
total_fixes = sum(len(f) for f in grouped_fixes.values())
print()
print("="*60)
print("PLAN GENERATION COMPLETE")
print("="*60)
print(f"Total fixes generated: {total_fixes}")
print(f"Duration: {plan_duration:.1f}s")

In [None]:
# Show all fixes by category (verbose)
def _clipped(fix, key, width):
    """Return fix[key] cut to *width* characters ('' when missing or falsy)."""
    return (fix.get(key) or '')[:width]


def _table_line(fix):
    """Detail line naming the table a fix targets."""
    return f"      Table: {fix.get('table')}"


def _column_line(fix):
    """Detail line naming the column a fix targets."""
    return f"      Column: {fix.get('column')}"


# fix type -> renderers producing that type's detail lines, in print order.
# Types not listed here get no detail lines.
_FIX_DETAIL_RENDERERS = {
    'add_synonym': (
        _table_line,
        _column_line,
        lambda fix: f"      Synonym: '{fix.get('synonym')}'",
    ),
    'delete_synonym': (
        _table_line,
        _column_line,
        lambda fix: f"      Remove: '{fix.get('synonym')}'",
    ),
    'add_column_description': (
        _table_line,
        _column_line,
        lambda fix: f"      Description: {_clipped(fix, 'description', 60)}...",
    ),
    'add_table_description': (
        _table_line,
        lambda fix: f"      Description: {_clipped(fix, 'description', 60)}...",
    ),
    'add_example_query': (
        lambda fix: f"      Pattern: {fix.get('pattern_name')}",
        lambda fix: f"      Question: {_clipped(fix, 'question', 50)}...",
        lambda fix: f"      SQL: {_clipped(fix, 'sql', 50)}...",
    ),
    'create_metric_view': (
        lambda fix: f"      View: {fix.get('catalog')}.{fix.get('schema')}.{fix.get('metric_view_name')}",
        lambda fix: f"      SQL: {_clipped(fix, 'sql', 60)}...",
    ),
    'update_text_instruction': (
        lambda fix: f"      Text: {_clipped(fix, 'instruction_text', 80)}...",
    ),
}

print("="*60)
print("GENERATED FIXES BY CATEGORY")
print("="*60)

for category in ("metric_view", "metadata", "sample_query", "instruction"):
    category_fixes = grouped_fixes.get(category, [])
    print(f"\n{'='*40}")
    print(f"{category.upper()} ({len(category_fixes)} fixes)")
    print("="*40)

    if category_fixes:
        for number, fix in enumerate(category_fixes, start=1):
            fix_type = fix.get('type', 'unknown')
            print(f"\n  [{number}] {fix_type}")

            # Type-specific details, rendered and printed one line at a time
            for render_line in _FIX_DETAIL_RENDERERS.get(fix_type, ()):
                print(render_line(fix))

            # Benchmark question the fix was derived from, when recorded
            origin = fix.get('source_failure', {})
            if origin:
                print(f"      Source: {origin.get('question', '')[:40]}...")
    else:
        print("  (none)")

print()
print("="*60)
print(f"TOTAL: {total_fixes} fixes ready to apply")
print("="*60)

---
# STAGE 3: APPLY (Batch)
---

In [None]:
# Batch applier: wired to the Space API and SQL executor created earlier,
# configured with the sandbox catalog and genie_enhancement schema.
print("Initializing Batch Applier...")
applier_config = {"catalog": "sandbox", "schema": "genie_enhancement"}
applier = BatchApplier(
    sql_executor=sql_executor,
    space_api=space_api,
    config=applier_config,
)
print("✅ Applier initialized")

In [None]:
# Stage 3 driver: passes every generated fix to the applier in one call.
# Runs as a dry run unless DRY_RUN is switched off.
DRY_RUN = True  # Set to False to actually apply

mode_label = '(LIVE)' if not DRY_RUN else '(DRY RUN)'
print("="*60)
print(f"STAGE 3: APPLY ALL FIXES {mode_label}")
print("="*60)
print()

if not DRY_RUN:
    print("LIVE MODE: Changes WILL be applied to the Genie Space!")
else:
    print("DRY RUN MODE: Changes will be simulated, not applied")
print()

# Time the whole batch call
apply_start = datetime.now()
apply_result = applier.apply_all(
    grouped_fixes=grouped_fixes,
    space_id=GENIE_SPACE_ID,
    dry_run=DRY_RUN,
)
apply_duration = (datetime.now() - apply_start).total_seconds()

applied_count = len(apply_result['applied'])
failed_count = len(apply_result['failed'])

print()
print("="*60)
print("APPLY COMPLETE")
print("="*60)
print(f"Applied: {applied_count}")
print(f"Failed: {failed_count}")
print(f"Duration: {apply_duration:.1f}s")

In [None]:
# List what the applier reported: the first ten applied fixes, then every
# failed fix together with its error.
print("\n✅ Applied Fixes:")
applied_preview = apply_result['applied'][:10]
for number, applied_fix in enumerate(applied_preview, start=1):
    print(f"  {number}. {applied_fix.get('type')}")

failed_fixes = apply_result['failed']
if failed_fixes:
    print("\n❌ Failed Fixes:")
    for number, failed_fix in enumerate(failed_fixes, start=1):
        print(f"  {number}. {failed_fix.get('type')}: {failed_fix.get('error')}")

In [None]:
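# LIVE RUN (uncomment to execute)
# WARNING: This will modify your Genie Space!
# NOTE: Stage 4 only waits for Genie indexing when DRY_RUN is False. This cell
# does not change DRY_RUN, so also set DRY_RUN = False before validating a
# live run (or re-run the Stage 3 apply cell with DRY_RUN = False instead).

# print("Applying fixes for real...")
# apply_result = applier.apply_all(
#     space_id=GENIE_SPACE_ID,
#     grouped_fixes=grouped_fixes,
#     dry_run=False
# )
# print(f"Applied: {len(apply_result['applied'])}")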

---
# STAGE 4: VALIDATE
---

In [None]:
# Pause so Genie can index the changes before re-scoring.
# Skipped for dry runs and when nothing was applied.
INDEXING_WAIT = 60  # seconds

if DRY_RUN or len(apply_result['applied']) == 0:
    print("Skipping wait (dry run or no changes)")
else:
    print(f"Waiting {INDEXING_WAIT}s for Genie indexing...")
    time.sleep(INDEXING_WAIT)
    print("✅ Wait complete")

In [None]:
# Stage 4: score the same benchmarks again and compare with the Stage 1 run.
print("="*60)
print("STAGE 4: VALIDATING RESULTS")
print("="*60)
print()
print("Re-scoring all benchmarks...")
print()

validate_start = datetime.now()
final_results = scorer.score(benchmarks)
validate_duration = (datetime.now() - validate_start).total_seconds()

# Score delta between the first and the second run
initial_score = score_results['score']
final_score = final_results['score']
improvement = final_score - initial_score

initial_tally = f"{score_results['passed']}/{score_results['total']} passed"
final_tally = f"{final_results['passed']}/{final_results['total']} passed"

print()
print("="*60)
print("VALIDATION COMPLETE")
print("="*60)
print()
print(f"Initial Score:  {initial_score:.1%} ({initial_tally})")
print(f"Final Score:    {final_score:.1%} ({final_tally})")
print(f"Improvement:    {improvement:+.1%}")
print(f"Target:         {TARGET_SCORE:.1%}")
print()
print(f"Validation duration: {validate_duration:.1f}s")
print()

# Choose the closing message first, then frame it once
if final_score >= TARGET_SCORE:
    outcome = "TARGET REACHED!"
elif improvement > 0:
    outcome = f"IMPROVED but need another loop (gap: {TARGET_SCORE - final_score:.1%})"
else:
    outcome = "NO IMPROVEMENT - check fix quality"

print("="*60)
print(outcome)
print("="*60)

---
# Debug Utilities
---

In [None]:
# Smoke-test the Genie client with a single question (60s timeout)
test_question = "What tables are available?"
print(f"Testing Genie: {test_question}")

response = genie_client.ask(test_question, timeout=60)
print(f"Status: {response['status']}")

# Show the start of the generated SQL when the response carries any
generated_sql = response.get('sql')
if generated_sql:
    print(f"SQL: {generated_sql[:100]}...")

In [None]:
# Smoke-test the LLM client with a fixed prompt (at most 50 tokens requested)
test_prompt = "Say 'Hello, Genie Enhancement is working!'"
print("Testing LLM...")

response = llm_client.generate(test_prompt, max_tokens=50)
print(f"Response: {response}")

In [None]:
# Dump the exported space config to debug_space_config.json in the working
# directory (2-space indented JSON)
output_file = "debug_space_config.json"
with open(output_file, 'w') as config_out:
    config_out.write(json.dumps(space_config, indent=2))
print(f"✅ Config saved to {output_file}")

In [None]:
# Dump the generated fixes to debug_fixes.json in the working directory.
# Values json cannot encode natively are written as their str() form.
output_file = "debug_fixes.json"
with open(output_file, 'w') as fixes_out:
    fixes_out.write(json.dumps(grouped_fixes, indent=2, default=str))
print(f"✅ Fixes saved to {output_file}")