# Databricks Genie API Load Testing with MLflow Tracking

This notebook demonstrates best practices for calling the Databricks Genie Space API:
- Exponential backoff with jitter
- Retry logic for 429 rate limit errors
- MLflow tracking for all API calls
- Concurrent request handling
- Configurable test scenarios


## 1. Setup and Configuration


In [None]:
import os
import time
import random
import json
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional, Any
import requests
import mlflow
import pandas as pd
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
    before_sleep_log
)
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


In [None]:
# ============================================================
# GENIE API CONNECTION CONFIGURATION
# ============================================================
WORKSPACE_URL = "https://your-workspace.cloud.databricks.com"  # Update this
GENIE_SPACE_ID = "your-genie-space-id"  # Update this
API_TOKEN = dbutils.secrets.get(scope="your-scope", key="your-key")  # Update this

# ============================================================
# MLFLOW EXPERIMENT CONFIGURATION
# ============================================================
MLFLOW_EXPERIMENT_NAME = "/Shared/genie-load-test"  # Change this to your desired experiment path
# Examples:
# - "/Shared/genie-load-test"
# - "/Users/your.email@company.com/genie-experiments"
# - "/Teams/data-engineering/genie-api-tests"

# ============================================================
# API RETRY CONFIGURATION
# ============================================================
MAX_RETRIES = 5           # Maximum number of retry attempts for 429 errors
BASE_WAIT_TIME = 1        # Initial wait time in seconds (exponential backoff base)
MAX_WAIT_TIME = 60        # Maximum wait time between retries in seconds
JITTER_MAX = 2            # Maximum random jitter in seconds to prevent thundering herd
TIMEOUT = 300             # Maximum time in seconds to wait for a single query response
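
To get a feel for what the retry constants imply, the sketch below computes an upper bound on the time one HTTP call can spend sleeping between retries. It assumes the wait doubles from `BASE_WAIT_TIME` on each retry, is capped at `MAX_WAIT_TIME`, and gains up to `JITTER_MAX` seconds each time. The exact schedule a retry library produces can differ by version, and `backoff_schedule` and `worst_case_wait` are hypothetical helpers, not part of the notebook.

```python
from typing import List


def backoff_schedule(max_attempts: int, base_wait: float, max_wait: float) -> List[float]:
    """Return the capped exponential waits between attempts (no jitter).

    With max_attempts attempts there are max_attempts - 1 waits.
    Assumes the wait doubles on each retry, starting from base_wait.
    """
    return [min(base_wait * (2 ** i), max_wait) for i in range(max_attempts - 1)]


def worst_case_wait(max_attempts: int, base_wait: float, max_wait: float, jitter_max: float) -> float:
    """Upper bound on total sleep time: every wait plus the maximum jitter."""
    waits = backoff_schedule(max_attempts, base_wait, max_wait)
    return sum(waits) + jitter_max * len(waits)


# Values mirror the configuration cell (MAX_RETRIES=5, BASE_WAIT_TIME=1, MAX_WAIT_TIME=60, JITTER_MAX=2)
schedule = backoff_schedule(max_attempts=5, base_wait=1, max_wait=60)
print(schedule)                       # [1, 2, 4, 8]
print(worst_case_wait(5, 1, 60, 2))   # 23
```

Under these assumptions, retries add at most about 23 seconds to a single HTTP call, which leaves most of the 300-second `TIMEOUT` for the query itself.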


## 2. Load Test Scenario Configuration

**Configure your load test scenarios here by adjusting the parameters below.**


In [None]:
# ============================================================
# LOAD TEST SCENARIO CONFIGURATIONS
# ============================================================
# Easily adjust these parameters to create different load test scenarios

# Predefined Load Test Scenarios
LOAD_TEST_SCENARIOS = {
    "single_test": {
        "num_questions": 1,
        "target_duration": None,  # No spreading - immediate submission
        "max_workers": 1,
        "description": "Single question test for validation"
    },
    "light_load": {
        "num_questions": 10,
        "target_duration": 60,  # 10 questions over 60 seconds
        "max_workers": 3,
        "description": "Light load - 10 questions over 1 minute"
    },
    "moderate_load": {
        "num_questions": 20,
        "target_duration": 30,  # 20 questions over 30 seconds
        "max_workers": 5,
        "description": "Moderate load - 20 questions over 30 seconds"
    },
    "high_load": {
        "num_questions": 35,
        "target_duration": 30,  # 35 questions over 30 seconds
        "max_workers": 10,
        "description": "High load - 35 questions over 30 seconds"
    },
    "stress_test": {
        "num_questions": 50,
        "target_duration": 20,  # 50 questions over 20 seconds
        "max_workers": 15,
        "description": "Stress test - 50 questions over 20 seconds"
    },
    "burst_test": {
        "num_questions": 50,
        "target_duration": None,  # No spreading - all at once
        "max_workers": 20,
        "description": "Burst test - 50 questions submitted immediately"
    },
    "sustained_load": {
        "num_questions": 100,
        "target_duration": 120,  # 100 questions over 2 minutes
        "max_workers": 10,
        "description": "Sustained load - 100 questions over 2 minutes"
    }
}

# ============================================================
# ACTIVE SCENARIO SELECTION
# ============================================================
# Select which scenario to run (change this to switch scenarios)
ACTIVE_SCENARIO = "high_load"  # Options: single_test, light_load, moderate_load, high_load, stress_test, burst_test, sustained_load

# OR - Create a custom scenario with your own parameters
# Modify the dictionary below, then set USE_CUSTOM_SCENARIO = True further down

CUSTOM_SCENARIO = {
    "num_questions": 35,        # Total number of questions to submit
    "target_duration": 30,      # Duration in seconds to spread questions over (None = submit all immediately)
    "max_workers": 10,          # Maximum concurrent threads/workers
    "description": "Custom load test scenario"
}

# Set to True to use CUSTOM_SCENARIO instead of ACTIVE_SCENARIO
USE_CUSTOM_SCENARIO = False

# ============================================================
# Helper function to get active scenario configuration
# ============================================================
def get_active_scenario():
    """Get the currently active scenario configuration"""
    if USE_CUSTOM_SCENARIO:
        scenario = CUSTOM_SCENARIO
        scenario_name = "custom"
    else:
        scenario = LOAD_TEST_SCENARIOS.get(ACTIVE_SCENARIO)
        scenario_name = ACTIVE_SCENARIO
        if scenario is None:
            raise ValueError(f"Unknown scenario: {ACTIVE_SCENARIO}. Available: {list(LOAD_TEST_SCENARIOS.keys())}")
    
    print(f"\n{'='*70}")
    print(f"ACTIVE LOAD TEST SCENARIO: {scenario_name.upper()}")
    print(f"{'='*70}")
    print(f"Description: {scenario['description']}")
    print(f"Number of Questions: {scenario['num_questions']}")
    print(f"Target Duration: {scenario['target_duration']} seconds" if scenario['target_duration'] else "Target Duration: None (immediate submission)")
    print(f"Max Workers: {scenario['max_workers']}")
    if scenario['target_duration']:
        rate = scenario['num_questions'] / scenario['target_duration']
        print(f"Expected Rate: {rate:.2f} questions/second")
    print(f"{'='*70}\n")
    
    return scenario_name, scenario

# Display active scenario
scenario_name, scenario_config = get_active_scenario()
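
Because scenarios are plain dictionaries, a typo or a zero value only surfaces once the test is running. The sketch below shows one way to check a scenario up front. `validate_scenario` is a hypothetical helper and the rules it enforces are assumptions based on how the keys are used in this notebook.

```python
from typing import Any, Dict, List

REQUIRED_KEYS = ("num_questions", "target_duration", "max_workers", "description")


def validate_scenario(name: str, scenario: Dict[str, Any]) -> List[str]:
    """Return a list of problems; an empty list means the scenario looks usable."""
    problems = [f"{name}: missing key '{key}'" for key in REQUIRED_KEYS if key not in scenario]
    if problems:
        return problems
    if not isinstance(scenario["num_questions"], int) or scenario["num_questions"] < 1:
        problems.append(f"{name}: num_questions must be a positive integer")
    if not isinstance(scenario["max_workers"], int) or scenario["max_workers"] < 1:
        problems.append(f"{name}: max_workers must be a positive integer")
    duration = scenario["target_duration"]
    if duration is not None and duration <= 0:
        problems.append(f"{name}: target_duration must be None or a positive number")
    return problems


# One valid and one invalid scenario (illustrative values)
examples = {
    "light_load": {"num_questions": 10, "target_duration": 60, "max_workers": 3, "description": "ok"},
    "broken": {"num_questions": 0, "target_duration": -5, "max_workers": 3, "description": "bad"},
}
for scenario_id, cfg in examples.items():
    print(scenario_id, validate_scenario(scenario_id, cfg))
```

Running this over `LOAD_TEST_SCENARIOS.items()` before a test would catch misconfigured entries early.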


## 3. Rate Limit Exception Classes


In [None]:
class RateLimitError(Exception):
    """Custom exception for rate limit errors (429)"""
    pass

class GenieAPIError(Exception):
    """Custom exception for Genie API errors"""
    pass
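
A 429 response may include a `Retry-After` header, which HTTP allows to be either a number of seconds or an HTTP date. The sketch below parses both forms and falls back to a default when the header is missing or malformed. `parse_retry_after` is a hypothetical helper, and whether the Genie API sends this header in a given workspace is an assumption to verify.

```python
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import Optional


def parse_retry_after(value: Optional[str], default: float, now: Optional[datetime] = None) -> float:
    """Convert a Retry-After header value to a non-negative number of seconds.

    Accepts delay-seconds ("120") or an HTTP date; returns `default`
    when the header is missing or cannot be parsed.
    """
    if value is None:
        return default
    value = value.strip()
    try:
        return max(0.0, float(value))
    except ValueError:
        pass  # Not a plain number; try the HTTP-date form next
    try:
        retry_at = parsedate_to_datetime(value)
    except (TypeError, ValueError):
        return default
    if retry_at.tzinfo is None:
        retry_at = retry_at.replace(tzinfo=timezone.utc)
    now = now or datetime.now(timezone.utc)
    return max(0.0, (retry_at - now).total_seconds())


# A fixed "now" keeps the date example deterministic
fixed_now = datetime(2015, 10, 21, 7, 27, 0, tzinfo=timezone.utc)
print(parse_retry_after("30", default=1.0))                                    # 30.0
print(parse_retry_after("Wed, 21 Oct 2015 07:28:00 GMT", 1.0, now=fixed_now))  # 60.0
print(parse_retry_after("soon", default=1.0))                                  # 1.0
print(parse_retry_after(None, default=1.0))                                    # 1.0
```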


## 4. Genie API Client and MLflow Logging


In [None]:
class GenieAPIClient:
    """
    Client for interacting with Databricks Genie Space API.
    Implements exponential backoff, jitter, and retry logic.
    """
    
    def __init__(
        self,
        workspace_url: str,
        space_id: str,
        token: str,
        max_retries: int = MAX_RETRIES,
        base_wait: int = BASE_WAIT_TIME,
        max_wait: int = MAX_WAIT_TIME
    ):
        self.workspace_url = workspace_url.rstrip('/')
        self.space_id = space_id
        self.token = token
        self.max_retries = max_retries
        self.base_wait = base_wait
        self.max_wait = max_wait
        self.base_url = f"{self.workspace_url}/api/2.0/genie/spaces/{self.space_id}"
        
    def _get_headers(self) -> Dict[str, str]:
        """Get request headers with authentication"""
        return {
            "Authorization": f"Bearer {self.token}",
            "Content-Type": "application/json"
        }
    
    def _add_jitter(self, wait_time: float) -> float:
        """Add random jitter to wait time to prevent thundering herd"""
        jitter = random.uniform(0, JITTER_MAX)
        return wait_time + jitter
    
    # Note: the decorator arguments are evaluated once, at class-definition time, so
    # retries use the module-level MAX_RETRIES / BASE_WAIT_TIME / MAX_WAIT_TIME rather
    # than the per-instance values passed to __init__.
    @retry(
        retry=retry_if_exception_type(RateLimitError),
        stop=stop_after_attempt(MAX_RETRIES),
        wait=wait_exponential(multiplier=BASE_WAIT_TIME, max=MAX_WAIT_TIME),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        reraise=True  # Surface RateLimitError instead of tenacity.RetryError when retries run out
    )
    def _make_request(
        self,
        method: str,
        endpoint: str,
        data: Optional[Dict] = None,
        params: Optional[Dict] = None
    ) -> Dict:
        """Make HTTP request with retry logic for rate limiting"""
        url = f"{self.base_url}/{endpoint}"
        
        try:
            response = requests.request(
                method=method,
                url=url,
                headers=self._get_headers(),
                json=data,
                params=params,
                timeout=30
            )
            
            # Handle rate limiting
            if response.status_code == 429:
                # Retry-After may be absent or an HTTP date rather than seconds;
                # fall back to the base wait when it cannot be read as a number
                try:
                    retry_after = float(response.headers.get('Retry-After', self.base_wait))
                except ValueError:
                    retry_after = self.base_wait
                wait_time = self._add_jitter(retry_after)
                logger.warning(f"Rate limited. Waiting {wait_time:.2f} seconds before retry")
                # This sleep is in addition to the exponential wait applied by the decorator
                time.sleep(wait_time)
                raise RateLimitError("Rate limit exceeded (429)")
            
            response.raise_for_status()
            return response.json()
            
        except requests.exceptions.HTTPError as e:
            logger.error(f"HTTP Error: {e}")
            raise GenieAPIError(f"API request failed: {e}") from e
        except requests.exceptions.RequestException as e:
            logger.error(f"Request Error: {e}")
            raise GenieAPIError(f"Request failed: {e}") from e
    
    def start_conversation(self, content: str) -> Dict:
        """Start a new conversation with a question"""
        data = {"content": content}
        return self._make_request("POST", "start-conversation", data=data)
    
    def get_message_query_result(self, conversation_id: str, message_id: str) -> Dict:
        """Get the query result for a message"""
        endpoint = f"conversations/{conversation_id}/messages/{message_id}/query-result"
        return self._make_request("GET", endpoint)
    
    def wait_for_result(
        self,
        conversation_id: str,
        message_id: str,
        timeout: int = TIMEOUT,
        poll_interval: int = 2
    ) -> Dict:
        """Poll for query result until completion or timeout"""
        start_time = time.time()
        
        while time.time() - start_time < timeout:
            try:
                result = self.get_message_query_result(conversation_id, message_id)
                status = result.get("status")
                
                if status == "COMPLETED":
                    return result
                elif status == "FAILED":
                    raise GenieAPIError(f"Query failed: {result.get('error', 'Unknown error')}")
                elif status in ["EXECUTING_QUERY", "QUERYING_HISTORY", "SUBMITTED"]:
                    time.sleep(poll_interval)
                else:
                    logger.warning(f"Unknown status: {status}")
                    time.sleep(poll_interval)
                    
            except RateLimitError:
                # Retries inside _make_request are already exhausted here; propagate to the caller
                raise
            except Exception as e:
                logger.error(f"Error polling for result: {e}")
                raise
        
        raise TimeoutError(f"Query timed out after {timeout} seconds")
    
    def ask_question(self, question: str, timeout: int = TIMEOUT) -> Dict:
        """Ask a question and wait for the result"""
        # Start conversation
        start_response = self.start_conversation(question)
        conversation_id = start_response["conversation_id"]
        message_id = start_response["message_id"]
        
        # Wait for result
        result = self.wait_for_result(conversation_id, message_id, timeout)
        
        return {
            "conversation_id": conversation_id,
            "message_id": message_id,
            "question": question,
            "result": result
        }

def log_question_to_mlflow(
    question: str,
    response: Dict,
    duration: float,
    error: Optional[str] = None,
    run_name: Optional[str] = None
) -> str:
    """
    Log a single question and its response to MLflow
    Returns the run_id
    """
    with mlflow.start_run(run_name=run_name, nested=True) as run:
        # Log parameters
        mlflow.log_param("question", question)
        mlflow.log_param("timestamp", datetime.now().isoformat())
        
        # Log metrics
        mlflow.log_metric("duration_seconds", duration)
        mlflow.log_metric("success", 1 if error is None else 0)
        
        if error:
            mlflow.log_param("error", error)
            mlflow.set_tag("status", "failed")
        else:
            mlflow.set_tag("status", "success")
            
            # Log conversation details
            mlflow.log_param("conversation_id", response.get("conversation_id"))
            mlflow.log_param("message_id", response.get("message_id"))
            
            # Log result details if available
            result = response.get("result", {})
            if "statement_response" in result:
                statement_response = result["statement_response"]
                if "result_data" in statement_response:
                    result_data = statement_response["result_data"]
                    mlflow.log_metric("row_count", result_data.get("row_count", 0))
            
            # Save full response as artifact. log_dict avoids writing to a shared
            # local "response.json", which would race across concurrent threads.
            mlflow.log_dict(response, "response.json")
        
        return run.info.run_id
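
The polling logic in `wait_for_result` can be exercised without a live workspace by separating the loop from the HTTP call. The sketch below is one way to do that under stated assumptions: `poll_until_terminal` is a hypothetical helper, the status names are taken from the client code in this notebook, and the fetch function is a stub rather than a real API call.

```python
import time
from typing import Any, Callable, Dict, Iterable


def poll_until_terminal(
    fetch: Callable[[], Dict[str, Any]],
    terminal_statuses: Iterable[str],
    timeout: float,
    poll_interval: float,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> Dict[str, Any]:
    """Call `fetch` until its 'status' is terminal or `timeout` seconds elapse."""
    terminal = set(terminal_statuses)
    deadline = clock() + timeout
    while True:
        result = fetch()
        if result.get("status") in terminal:
            return result
        # Stop if the next poll would land after the deadline
        if clock() + poll_interval > deadline:
            raise TimeoutError(f"No terminal status after {timeout} seconds")
        sleep(poll_interval)


# Stubbed usage: the fake API reports two in-progress states, then COMPLETED
responses = iter([{"status": "SUBMITTED"}, {"status": "EXECUTING_QUERY"}, {"status": "COMPLETED"}])
final = poll_until_terminal(
    fetch=lambda: next(responses),
    terminal_statuses=("COMPLETED", "FAILED"),
    timeout=5,
    poll_interval=0,
    sleep=lambda seconds: None,  # skip real sleeping in the example
)
print(final["status"])  # COMPLETED
```

Injecting `sleep` and `clock` keeps the loop testable; using a monotonic clock also avoids surprises if the system time changes during a long poll.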


## 5. Load Test Orchestrator


In [None]:
class LoadTestOrchestrator:
    """
    Orchestrates load testing with concurrent requests and MLflow tracking
    """
    
    def __init__(self, client: GenieAPIClient, experiment_name: str = "/Shared/genie-load-test"):
        self.client = client
        self.experiment_name = experiment_name
        mlflow.set_experiment(self.experiment_name)
    
    def _execute_single_question(
        self,
        question: str,
        question_index: int,
        delay: float = 0
    ) -> Dict:
        """Execute a single question with optional delay"""
        if delay > 0:
            time.sleep(delay)
        
        start_time = time.time()
        error = None
        response = None
        
        try:
            logger.info(f"[Q{question_index}] Asking: {question[:50]}...")
            response = self.client.ask_question(question)
            duration = time.time() - start_time
            logger.info(f"[Q{question_index}] Completed in {duration:.2f}s")
            
        except Exception as e:
            duration = time.time() - start_time
            error = str(e)
            logger.error(f"[Q{question_index}] Failed after {duration:.2f}s: {error}")
        
        # Log to MLflow
        run_name = f"question_{question_index}"
        run_id = log_question_to_mlflow(question, response or {}, duration, error, run_name)
        
        return {
            "question_index": question_index,
            "question": question,
            "duration": duration,
            "success": error is None,
            "error": error,
            "response": response,
            "run_id": run_id
        }
    
    def run_load_test(
        self,
        questions: List[str],
        max_workers: int = 10,
        target_duration: Optional[float] = None,
        test_name: str = "load_test"
    ) -> pd.DataFrame:
        """
        Run load test with multiple questions
        
        Args:
            questions: List of questions to ask
            max_workers: Maximum number of concurrent threads
            target_duration: Target duration in seconds to spread questions over (e.g., 30 for "35 questions in 30 seconds")
            test_name: Name for the MLflow run
        
        Returns:
            DataFrame with test results
        """
        with mlflow.start_run(run_name=test_name) as parent_run:
            # Log test configuration
            mlflow.log_param("num_questions", len(questions))
            mlflow.log_param("max_workers", max_workers)
            mlflow.log_param("target_duration", target_duration or "None")
            mlflow.log_param("test_name", test_name)
            mlflow.log_param("start_time", datetime.now().isoformat())
            
            # Calculate delays if target duration is specified
            delays = [0] * len(questions)
            if target_duration and len(questions) > 1:
                interval = target_duration / len(questions)
                delays = [i * interval for i in range(len(questions))]
            
            # Execute questions concurrently
            results = []
            test_start_time = time.time()
            
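            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                # Submit each question at its scheduled offset from the test start.
                # Sleeping inside the worker instead ties up a thread per pending question
                # and skews the schedule once all workers are busy.
                futures = {}
                for i, question in enumerate(questions):
                    remaining = delays[i] - (time.time() - test_start_time)
                    if remaining > 0:
                        time.sleep(remaining)
                    futures[executor.submit(self._execute_single_question, question, i)] = i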
                
                for future in as_completed(futures):
                    try:
                        result = future.result()
                        results.append(result)
                    except Exception as e:
                        logger.error(f"Unexpected error: {e}")
            
            total_duration = time.time() - test_start_time
            
            # Log summary metrics
            successful = sum(1 for r in results if r["success"])
            failed = len(results) - successful
            avg_duration = sum(r["duration"] for r in results) / len(results) if results else 0
            
            mlflow.log_metric("total_duration_seconds", total_duration)
            mlflow.log_metric("successful_questions", successful)
            mlflow.log_metric("failed_questions", failed)
            mlflow.log_metric("success_rate", successful / len(results) if results else 0)
            mlflow.log_metric("avg_question_duration", avg_duration)
            mlflow.log_metric("throughput_qps", len(results) / total_duration if total_duration > 0 else 0)
            
            # Create results DataFrame
            df = pd.DataFrame(results)
            
            # Save results
            results_file = "load_test_results.csv"
            df.to_csv(results_file, index=False)
            mlflow.log_artifact(results_file)
            
            logger.info(f"\n{'='*60}")
            logger.info(f"Load Test Summary: {test_name}")
            logger.info(f"{'='*60}")
            logger.info(f"Total Questions: {len(results)}")
            logger.info(f"Successful: {successful}")
            logger.info(f"Failed: {failed}")
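            # Guard against division by zero when no results were collected
            success_pct = successful / len(results) * 100 if results else 0.0
            logger.info(f"Success Rate: {success_pct:.2f}%")
            logger.info(f"Total Duration: {total_duration:.2f}s")
            logger.info(f"Avg Question Duration: {avg_duration:.2f}s")
            throughput = len(results) / total_duration if total_duration > 0 else 0.0
            logger.info(f"Throughput: {throughput:.2f} QPS")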
            logger.info(f"{'='*60}\n")
            
            return df
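
The pacing rule in `run_load_test` (`interval = target_duration / len(questions)`) is easier to reason about with concrete numbers. The sketch below reproduces that calculation on its own; `submission_offsets` is a hypothetical helper that mirrors the orchestrator's arithmetic and is not called by it.

```python
from typing import List, Optional


def submission_offsets(num_questions: int, target_duration: Optional[float]) -> List[float]:
    """Seconds after test start at which each question is scheduled.

    Mirrors the orchestrator's rule: interval = target_duration / num_questions,
    or all zeros when no target duration is set.
    """
    if not target_duration or num_questions <= 1:
        return [0.0] * num_questions
    interval = target_duration / num_questions
    return [i * interval for i in range(num_questions)]


offsets = submission_offsets(35, 30)      # the "high_load" scenario
print(f"interval: {offsets[1]:.3f}s")     # interval: 0.857s
print(f"last start: {offsets[-1]:.2f}s")  # last start: 29.14s
print(submission_offsets(3, None))        # [0.0, 0.0, 0.0]
```

Note that the last question is scheduled one interval before `target_duration`, so the target describes the submission window, not the time by which all answers arrive.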


## 6. Test Questions


In [None]:
# Sample questions for testing
SAMPLE_QUESTIONS = [
    "What is the total sales for last month?",
    "Show me top 10 customers by revenue",
    "What are the sales trends over the last 6 months?",
    "Which products have the highest profit margin?",
    "What is the average order value?",
    "Show me sales by region",
    "What is the customer churn rate?",
    "How many new customers joined last month?",
    "What is the inventory level for top products?",
    "Show me the sales forecast for next quarter"
]

# Generate more questions for load testing
def generate_test_questions(base_questions: List[str], count: int) -> List[str]:
    """Generate test questions by repeating and varying base questions"""
    questions = []
    for i in range(count):
        base_question = base_questions[i % len(base_questions)]
        questions.append(f"{base_question} (iteration {i // len(base_questions) + 1})")
    return questions
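
`generate_test_questions` cycles through the base list and appends an iteration counter, which fails with a `ZeroDivisionError` if the base list is empty. The sketch below is a guarded variant with a small usage example. `generate_questions_checked` is a hypothetical name, and rejecting a negative count is an added assumption.

```python
from typing import List


def generate_questions_checked(base_questions: List[str], count: int) -> List[str]:
    """Like generate_test_questions, but rejects inputs that would fail or surprise."""
    if not base_questions:
        raise ValueError("base_questions must contain at least one question")
    if count < 0:
        raise ValueError("count must be non-negative")
    size = len(base_questions)
    return [f"{base_questions[i % size]} (iteration {i // size + 1})" for i in range(count)]


sample = generate_questions_checked(
    ["Show me sales by region", "What is the average order value?"], 3
)
for text in sample:
    print(text)
# Show me sales by region (iteration 1)
# What is the average order value? (iteration 1)
# Show me sales by region (iteration 2)
```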


## 7. Initialize Client and Orchestrator


In [None]:
# Initialize the Genie API client
client = GenieAPIClient(
    workspace_url=WORKSPACE_URL,
    space_id=GENIE_SPACE_ID,
    token=API_TOKEN
)

# Initialize the load test orchestrator with configured MLflow experiment
orchestrator = LoadTestOrchestrator(
    client=client,
    experiment_name=MLFLOW_EXPERIMENT_NAME
)

print("✓ Client and orchestrator initialized successfully")
print(f"✓ MLflow Experiment: {MLFLOW_EXPERIMENT_NAME}")


## 8. Run Load Test with Active Scenario

**This cell runs the load test using the active scenario configured in Section 2.**


In [None]:
# Get the active scenario configuration from Section 2
scenario_name, config = get_active_scenario()

# Generate test questions based on configured number
test_questions = generate_test_questions(SAMPLE_QUESTIONS, config['num_questions'])

# Run the load test with the configured parameters
results_df = orchestrator.run_load_test(
    questions=test_questions,
    max_workers=config['max_workers'],
    target_duration=config['target_duration'],
    test_name=f"load_test_{scenario_name}"
)

# Display results summary
display(results_df[['question_index', 'question', 'duration', 'success', 'error']].head(20))


## 9. Analyze Results


In [None]:
def analyze_results(df: pd.DataFrame):
    """Analyze and visualize test results"""
    print("\n" + "="*60)
    print("DETAILED ANALYSIS")
    print("="*60)
    
    # Success rate
    success_rate = df['success'].mean() * 100
    print(f"\nSuccess Rate: {success_rate:.2f}%")
    
    # Duration statistics
    print("\nDuration Statistics (seconds):")
    print(f"  Min: {df['duration'].min():.2f}")
    print(f"  Max: {df['duration'].max():.2f}")
    print(f"  Mean: {df['duration'].mean():.2f}")
    print(f"  Median: {df['duration'].median():.2f}")
    print(f"  Std Dev: {df['duration'].std():.2f}")
    
    # Percentiles
    print("\nDuration Percentiles (seconds):")
    print(f"  P50: {df['duration'].quantile(0.50):.2f}")
    print(f"  P75: {df['duration'].quantile(0.75):.2f}")
    print(f"  P90: {df['duration'].quantile(0.90):.2f}")
    print(f"  P95: {df['duration'].quantile(0.95):.2f}")
    print(f"  P99: {df['duration'].quantile(0.99):.2f}")
    
    # Error analysis
    if not df['success'].all():
        print("\nError Summary:")
        error_counts = df[df['error'].notna()]['error'].value_counts()
        for error, count in error_counts.items():
            print(f"  {error}: {count}")
    
    print("\n" + "="*60 + "\n")
    
    return df

# Analyze the test results
analyzed_df = analyze_results(results_df)
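
The percentiles reported by `analyze_results` come from pandas' `quantile`, which interpolates linearly between neighbouring values by default. The sketch below shows that calculation without pandas so the numbers are easier to check by hand. `percentile` is a hypothetical helper and the durations are illustrative values, not measured results.

```python
import math
from typing import Sequence


def percentile(values: Sequence[float], q: float) -> float:
    """Linear-interpolation percentile for 0 <= q <= 1."""
    if not values:
        raise ValueError("values must not be empty")
    if not 0 <= q <= 1:
        raise ValueError("q must be between 0 and 1")
    ordered = sorted(values)
    position = q * (len(ordered) - 1)  # fractional index into the sorted list
    lower = math.floor(position)
    upper = math.ceil(position)
    fraction = position - lower
    return ordered[lower] + (ordered[upper] - ordered[lower]) * fraction


durations = [1.0, 2.0, 3.0, 4.0, 10.0]  # illustrative values only
print(f"P50: {percentile(durations, 0.50):.2f}")  # P50: 3.00
print(f"P75: {percentile(durations, 0.75):.2f}")  # P75: 4.00
print(f"P90: {percentile(durations, 0.90):.2f}")  # P90: 7.60
```

With only a few dozen samples, as in most scenarios here, P95 and P99 are driven almost entirely by the one or two slowest questions, so they are best read as rough indicators.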


## 10. Run Additional Custom Test (Optional)

**You can run an additional test with different parameters without changing the main configuration.**


In [None]:
# Example: Run an additional test with custom parameters
# Modify these values as needed
CUSTOM_NUM_QUESTIONS = 20
CUSTOM_TARGET_DURATION = 60  # seconds (None = immediate submission)
CUSTOM_MAX_WORKERS = 5

# Generate and run custom test
custom_questions = generate_test_questions(SAMPLE_QUESTIONS, CUSTOM_NUM_QUESTIONS)
custom_results_df = orchestrator.run_load_test(
    questions=custom_questions,
    max_workers=CUSTOM_MAX_WORKERS,
    target_duration=CUSTOM_TARGET_DURATION,
    test_name="custom_adhoc_test"
)

# Analyze custom results
analyze_results(custom_results_df)


## 11. View MLflow Tracking UI

To view the detailed logs and metrics in MLflow:

1. Navigate to your Databricks workspace
2. Go to the MLflow Experiments page
3. Find the experiment with the name configured in `MLFLOW_EXPERIMENT_NAME` (Section 1)
   - Default: `/Shared/genie-load-test`
   - You can change this in Section 1 to organize your experiments
4. Click on any run to see detailed metrics, parameters, and artifacts

Each question execution is logged as a nested run with:
- Question text
- Duration
- Success/failure status
- Full API response (as artifact)
- Conversation and message IDs
- Row counts (if applicable)

**Tip**: Use different experiment names for different testing purposes:
- `/Shared/genie-load-test-prod` - Production load tests
- `/Shared/genie-load-test-dev` - Development tests
- `/Users/your.email@company.com/genie-experiments` - Personal experiments
