# DAX Performance Testing

## Summary

This notebook is designed to measure DAX query timings under different cache states (cold, warm, and hot).

**Requirements:**

1. **DAX Queries from Excel or YAML**  
   - Provide an Excel file with a table of the DAX queries you want to test, or an equivalent YAML file.  
   - Each row needs a `queryId`, plus a column whose name matches the `runQueryType` configured for the model; that column holds the query text to run.  
   - This notebook reads those queries and executes them on one or more Power BI/Fabric models.
   - For YAML files, refer to `sample-yaml-dax-query-file.yaml` in the media folder.


2. **Lakehouse Logging**  
   - Attach the appropriate Lakehouse in Fabric so that logs can be saved (in a table and, optionally, as files).  

3. **Capacity Pause/Resume**  
   - In some scenarios (e.g., simulating a "cold" cache on DirectQuery or Import models), the code pauses and resumes capacities.  
   - **Warning**: Pausing a capacity will interrupt any running workloads on that capacity. Resuming will take time and resources, and can affect other workspaces assigned to the same capacity.

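As a rough illustration of the query layout described in the first requirement, the sketch below models the query table as a list of rows and picks the expression for one model. The variant column names (`query`, `query_optimized`), the DAX text, and the helper `expression_for` are hypothetical placeholders; only the rule that `runQueryType` names the column to run comes from the requirements above.

```python
# Hypothetical query rows: one entry per queryId, one key per query variant.
# The notebook itself loads such rows into a pandas DataFrame.
sample_queries = [
    {"queryId": 1, "query": "EVALUATE {1}", "query_optimized": "EVALUATE {1}"},
    {"queryId": 2, "query": "EVALUATE {2}", "query_optimized": 'EVALUATE ROW("x", 2)'},
]

# Hypothetical model entry; runQueryType selects which column is executed.
sample_model = {"name": "Model Name", "runQueryType": "query_optimized"}


def expression_for(row: dict, model: dict) -> str:
    """Return the DAX text stored in the column named by the model's runQueryType."""
    return row[model["runQueryType"]]


selected = [expression_for(row, sample_model) for row in sample_queries]
print(selected)
```

Pointing two models at different `runQueryType` columns lets the same `queryId` run a different variant of the query on each model.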

### Install the latest Semantic Link Labs package

Check [here](https://pypi.org/project/semantic-link-labs/) to see the latest version.

In [None]:
%pip install semantic-link-labs

### Import the library and necessary packages

In [None]:
# Standard Library Imports
import time
import itertools
import functools
import builtins
import re
import xml.etree.ElementTree as ET
from threading import local
from typing import Any, Callable, Generator, Type
from contextlib import contextmanager
from uuid import uuid4
from datetime import datetime

# Third-party Imports
import pandas as pd
import requests
from pyspark.sql.functions import col, sum as _sum, when, countDistinct

# Local Application/Library-specific Imports
import sempy.fabric as fabric
import sempy_labs as labs
import yaml

### Global configurations & variables

In [None]:
# Generate a unique run ID for this test run
run_id = str(uuid4())
file_location = "Files/config/yaml/active/sample-yaml-dax-query-file.yaml"
file_type = file_location.split('.')[-1]

query_file_path = "Files/DAXQueries.xlsx"  # Path to the query file relative to the mount
query_file_mount_path = "/default"              # Mount location where the file is stored
query_worksheet_name = "DAXQueries"          # Worksheet name (for Excel files)

# Define models and their configurations for testing
models = [
    {
        "name": "Model Name", # The name of the semantic model
        "storageMode": "DirectLake",  # Import, DirectQuery, or DirectLake
        "cache_types": ["cold", "warm", "hot"], # List of cache types to be run (hot, warm, and cold)
        "model_workspace_name": "Model Workspace Name", # The workspace name of the semantic model
        "database_name": "Lakehouse Name",  # Only needed for cold cache queries for Import and DirectQuery
        "database_workspace_name": "Lakehouse Workspace Name",  # Only needed for cold cache queries for Import and DirectQuery
        "runQueryType": "query", # The name of the column in your DAX query file that contains the query to be run
    },
]

# Only needed for cold cache queries for Import and DirectQuery
workspace_capacities = {
    "Workspace Name": {
        "capacity_name": "Testing Capacity Name",
        "alt_capacity_name": "Alternate Capacity Name",
    }
}

# Additional arguments controlling the behavior of query execution and logging
additional_arguments = {
    "roundNumber": 1, # The current round of DAX testing. Will be considered when determining if maxNumberPerQuery is met or not
    "onlyRunNewQueries": True, # Will determine if queries will stop being tested after maxNumberPerQuery is met
    "maxNumberPerQuery": 1, # The max number of queries to capture per round, queryId, model and cache type
    "maxFailuresBeforeSkipping": 5, # The number of failed query attempts per round, queryId, model and cache type before skipping
    "numberOfRunsPerQueryId": 15, # The number of times to loop over each queryId. If all combos have met maxNumberPerQuery, the loop will break
    "stopQueryIdsAt": 99, # Allows you to stop the queryId loop at a certain number, even if there are more queries present, i.e., there are queryIds 1-20 but stop at 5
    "forceStartQueriesAt1": False, # If set to False, testing will start at the first incomplete queryId instead of starting at queryId 1
    "logTableName": "DAXTestingLogTableName", # The name of the table in the attached lakehouse to save the performance logs to
    "clearAllLogs": False, # Will drop the existing logs table before starting testing
    "clearCurrentRoundLogs": False, # Will delete the logs associated with the current roundNumber before starting testing
    "randomizeRuns": True, # Will randomize the model and cache type combos when testing
    "skipSettingHotCache": False, # Should be False if randomizing the runs. If the runs are randomized, the previous warm cache run will set the hot cache
    "pauseAfterSettingCache": 5, # The number of seconds to wait after setting the cache
    "pauseAfterRunningQuery": 5, # The number of seconds to wait before writing the logs to the log table
    "pauseBetweenRuns": 30, # The number of seconds to wait before starting the next query
}

# Define the expected schema for DAX trace log events
event_schema = {
    "DirectQueryBegin": [
        "EventClass", "CurrentTime", "TextData", "StartTime", 
        "EndTime", "Duration", "CpuTime", "Success", "SessionID"
    ],
    "DirectQueryEnd": [
        "EventClass", "CurrentTime", "TextData", "StartTime", 
        "EndTime", "Duration", "CpuTime", "Success", "SessionID"
    ],
    "VertiPaqSEQueryBegin": [
        "EventClass", "EventSubclass", "CurrentTime", 
        "TextData", "StartTime", "SessionID"
    ],
    "VertiPaqSEQueryEnd": [
        "EventClass", "EventSubclass", "CurrentTime", "TextData", 
        "StartTime", "EndTime", "Duration", "CpuTime", "Success", "SessionID"
    ],
    "VertiPaqSEQueryCacheMatch": [
        "EventClass", "EventSubclass", "CurrentTime", "TextData", "SessionID"
    ],
    "QueryBegin": [
        "EventClass", "EventSubclass", "CurrentTime", "TextData", 
        "StartTime", "ConnectionID", "SessionID", "RequestProperties"
    ],
    "QueryEnd": [
        "EventClass", "EventSubclass", "CurrentTime", "TextData", 
        "StartTime", "EndTime", "Duration", "CpuTime", "Success", 
        "ConnectionID", "SessionID"
    ],
}

# Dictionary to track if a capacity pause is needed for each model during testing
model_pause_capacity_needed = {}

# Variables for Pausing/Resuming Capacities: credentials and configuration parameters for Azure Key Vault and resource management
resource_group_name = ""
subscription_id = ""
key_vault_uri_secret_name = ""
key_vault_client_id_secret_name = ""
key_vault_tenant_id_secret_name = ""
key_vault_client_secret_secret_name = ""

# Enforce case-sensitivity in Spark to ensure column name matching is exact
spark.conf.set("spark.sql.caseSensitive", True)

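Because the cold-cache path for Import and DirectQuery models looks up the model's workspace in `workspace_capacities`, a quick pre-flight check can catch a missing entry before a long run starts. The sketch below is one possible check, not part of the notebook: the helper name `find_missing_capacity_config` and the sample data are hypothetical, and it assumes that DirectQuery models also need the capacity entry for warm runs because the warm path sets a cold cache first.

```python
def find_missing_capacity_config(models: list, workspace_capacities: dict) -> list:
    """Return names of models that need a capacity pause but lack a complete config entry."""
    required_keys = {"capacity_name", "alt_capacity_name"}
    missing = []
    for m in models:
        cache_types = m["cache_types"]
        # DirectLake never pauses a capacity; Import/DirectQuery do for cold runs,
        # and DirectQuery also does for warm runs.
        needs_pause = m["storageMode"] != "DirectLake" and (
            "cold" in cache_types
            or (m["storageMode"] == "DirectQuery" and "warm" in cache_types)
        )
        entry = workspace_capacities.get(m["model_workspace_name"], {})
        if needs_pause and not required_keys.issubset(entry):
            missing.append(m["name"])
    return missing


# Hypothetical sample configuration for illustration only.
sample_models = [
    {"name": "A", "storageMode": "Import", "cache_types": ["cold", "hot"], "model_workspace_name": "WS1"},
    {"name": "B", "storageMode": "DirectLake", "cache_types": ["cold"], "model_workspace_name": "WS2"},
    {"name": "C", "storageMode": "DirectQuery", "cache_types": ["warm"], "model_workspace_name": "WS3"},
]
sample_capacities = {"WS1": {"capacity_name": "Cap", "alt_capacity_name": "Alt"}}

missing_models = find_missing_capacity_config(sample_models, sample_capacities)
print(missing_models)
```

In this sample only model `C` is reported: `A` has a complete entry and `B` uses DirectLake, which does not pause a capacity.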
### Logging & Retry Decorators, Basic Helpers

In [None]:
# Thread-local storage variable to manage indentation for logging.
_thread_local = local()

@contextmanager
def dynamic_indented_print() -> Generator[None, None, None]:
    """
    A context manager that overrides the built-in print function to automatically
    indent messages based on the current call depth. This helps visualize nested logs.
    """
    original_print = builtins.print

    def custom_print(*args: Any, **kwargs: Any) -> None:
        # Use current call depth (default 0) to set indentation.
        depth = getattr(_thread_local, "call_depth", 0)
        indent = "    " * depth
        original_print(indent + " ".join(map(str, args)), **kwargs)

    builtins.print = custom_print
    try:
        yield
    finally:
        builtins.print = original_print

def log_function_calls(func: Callable) -> Callable:
    """
    Decorator that logs the start and end of a function call.
    It increases the indentation for nested function calls for better readability.
    """
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if not hasattr(_thread_local, "call_depth"):
            _thread_local.call_depth = 0

        with dynamic_indented_print():
            print(f"✅ {func.__name__} - Starting")
            _thread_local.call_depth += 1  # Increase indentation level for inner calls
            try:
                result = func(*args, **kwargs)
            finally:
                _thread_local.call_depth -= 1  # Decrease indentation level when done
                print(f"✅ {func.__name__} - Ending")
        return result

    return wrapper

def retry(exceptions: tuple[Type[Exception], ...],
          tries: int = 3,
          delay: int = 5,
          backoff: int = 2,
          logger: Callable = print) -> Callable:
    """
    A decorator for retrying a function call with exponential backoff.
    This is useful for transient errors (e.g., waiting for a capacity status update).
    """
    def decorator_retry(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper_retry(*args: Any, **kwargs: Any) -> Any:
            _tries, _delay = tries, delay
            while _tries > 1:
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    with dynamic_indented_print():
                        print(f"⚠️ {func.__name__} failed with {e}, retrying in {_delay} seconds...")
                    time.sleep(_delay)
                    _tries -= 1
                    _delay *= backoff  # Increase the delay for each retry
            return func(*args, **kwargs)
        return wrapper_retry
    return decorator_retry

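To get a feel for how long a decorated function may block, the sketch below computes the sleep schedule that the `retry` loop above would follow when every attempt before the last one fails. The helper `retry_sleep_schedule` is hypothetical and only mirrors the loop's arithmetic; it does not call the decorator.

```python
def retry_sleep_schedule(tries: int, delay: float, backoff: float) -> list:
    """Sleeps taken by the retry loop when the first tries - 1 attempts all fail.

    The loop sleeps after each failed attempt while more than one try remains,
    multiplying the delay by backoff each time, so there are tries - 1 sleeps.
    """
    return [delay * backoff ** attempt for attempt in range(max(tries - 1, 0))]


# Settings used by wait_for_trace_start: doubling delays.
trace_start_sleeps = retry_sleep_schedule(tries=10, delay=2, backoff=2)
print(trace_start_sleeps)       # [2, 4, 8, 16, 32, 64, 128, 256, 512]
print(sum(trace_start_sleeps))  # 1022

# Settings used by wait_for_query_end_event: constant 3-second delays.
query_end_sleeps = retry_sleep_schedule(tries=60, delay=3, backoff=1)
print(sum(query_end_sleeps))    # 177
```

With `backoff=2` the total wait grows quickly as `tries` increases, whereas `backoff=1` keeps a steady polling interval.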
### Helper functions

In [None]:
def to_camel_case(s: str) -> str:
    """
    Converts a string to camel case.
    For example, 'Event Class' becomes 'eventClass'.
    """
    words = s.split()
    if not words:
        return s
    return words[0].lower() + ''.join(word.capitalize() for word in words[1:])

@log_function_calls
def trace_started(traces, trace_name: str) -> bool:
    """
    Checks if a specific trace (by name) has started by inspecting the traces DataFrame.
    """
    return traces.loc[traces["Name"] == trace_name].shape[0] > 0

@log_function_calls
@retry(Exception, tries=10, delay=2, backoff=2, logger=print)
def wait_for_trace_start(trace_connection, trace_name: str) -> bool:
    """
    Polls until the trace with the specified name is detected as started.
    Raises an exception if not started, triggering a retry.
    """
    if not trace_started(trace_connection.list_traces(), trace_name):
        raise Exception("Trace has not started yet")
    return True

def extract_session_id_from_query_begin(logs) -> str:
    """
    Extracts the Session ID from QueryBegin events that include 'SemPy' in their Request Properties.
    This is used to match corresponding QueryEnd events.
    """
    query_begin_events = logs[logs["Event Class"] == "QueryBegin"]
    for idx, row in query_begin_events.iterrows():
        req_props = row.get("Request Properties")
        if req_props and "SemPy" in req_props:
            try:
                root = ET.fromstring(req_props)
                for child in root:
                    if child.tag.endswith("SspropInitAppName") and child.text == "SemPy":
                        return row["Session ID"]
            except Exception as e:
                print(f"Error parsing Request Properties XML: {e}")
    return None

@log_function_calls
@retry(Exception, tries=60, delay=3, backoff=1, logger=print)
def wait_for_query_end_event(trace) -> tuple:
    """
    Waits until a QueryBegin event with 'SemPy' is detected and its corresponding QueryEnd event is present.
    Returns a tuple of (session_id, logs).
    """
    logs = trace.get_trace_logs()
    session_id = extract_session_id_from_query_begin(logs)
    if not session_id:
        raise Exception("QueryBegin event with SspropInitAppName 'SemPy' not found yet")
    
    query_end_events = logs[(logs["Event Class"] == "QueryEnd") & (logs["Session ID"] == session_id)]
    if query_end_events.empty:
        raise Exception("QueryEnd event for Session ID not collected yet")
    return session_id, logs

@log_function_calls
def collect_filtered_trace_logs(trace):
    """
    Waits for the QueryEnd event, stops the trace, and filters the logs for the relevant Session ID.
    This returns only the logs related to the current query execution.
    """
    session_id, _ = wait_for_query_end_event(trace)
    logs = trace.stop()
    filtered_logs = logs[logs["Session ID"] == session_id].copy()  # Copy so callers can add columns without pandas chained-assignment warnings
    return filtered_logs

@log_function_calls
@retry(Exception, tries=30, delay=5, backoff=1, logger=print)
def check_model_online(model: dict) -> None:
    """
    Checks if the given model is online by executing a simple DAX query.
    If the query fails, an exception is raised.
    """
    dax_query_eval_1(model)

@log_function_calls
def get_capacity_status(_capacity_name):
    """
    Retrieves the current status of a capacity by its name.
    """
    capacity_info = labs.list_capacities()
    return capacity_info.loc[
        capacity_info["Display Name"] == _capacity_name, "State"
    ].iloc[0]

@log_function_calls
@retry(Exception, tries=30, delay=2, backoff=2, logger=print)
def wait_for_capacity_status(_capacity_name, target_status):
    """
    Polls until the capacity status matches the target status.
    """
    current_status = get_capacity_status(_capacity_name)
    
    if current_status != target_status:
        raise Exception("Capacity status not updated yet")
    return current_status

def dax_query_eval_1(model: dict) -> None:
    """
    Executes a simple DAX query ("EVALUATE {1}") to verify connectivity and responsiveness.
    """
    try:
        fabric.evaluate_dax(
            model["name"], "EVALUATE {1}", workspace=model["model_workspace_name"]
        )
    except Exception as e:
        raise Exception("Failed to query model") from e

@log_function_calls
def wait_for_model_to_come_online(model: dict) -> None:
    """
    Waits until the model is online by calling check_model_online.
    Raises an exception if the model fails to come online.
    """
    try:
        check_model_online(model)
        print("✅ Model is online")
    except Exception as e:
        raise Exception("❌ Model failed to come online") from e

@log_function_calls
def load_dax_queries(file_path: str, mount_path: str, worksheet_name: str = None) -> pd.DataFrame:
    """
    Loads the DAX queries from the given file. Supports Excel and YAML formats.
    The first column must be 'queryId' and additional columns should contain variants of the DAX query.
    
    Args:
        file_path (str): Relative path to the query file.
        mount_path (str): The mount path where the file is stored.
        worksheet_name (str, optional): Worksheet name for Excel files.
        
    Returns:
        pd.DataFrame: DataFrame containing the DAX queries.
    """
    file_type = file_path.split('.')[-1].lower()
    full_path = f"{notebookutils.fs.getMountPath(mount_path)}/{file_path}"
    
    if file_type == 'xlsx':
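        # Use the worksheet_name if provided; otherwise read the first sheet.
        # Passing sheet_name=None would return a dict of all sheets instead of a DataFrame.
        return pd.read_excel(full_path, sheet_name=worksheet_name if worksheet_name is not None else 0)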
    elif file_type in ['yml', 'yaml']:
        with open(full_path, 'r') as f:
            data = yaml.safe_load(f)  # safe_load only builds plain Python objects
        return pd.DataFrame(data)
    else:
        raise ValueError(f"Unsupported file type: {file_type}")

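The session lookup above relies on the Request Properties column holding an XML property list in which one child element ends with `SspropInitAppName`. The sketch below shows that tag matching on a made-up payload; the XML shown is an assumption about the general shape and real trace output may carry different or additional properties. The helper name `app_name_from_request_properties` is hypothetical.

```python
import xml.etree.ElementTree as ET
from typing import Optional

# Illustrative payload only; not captured from a real trace.
sample_request_properties = (
    '<PropertyList xmlns="urn:schemas-microsoft-com:xml-analysis">'
    "<Catalog>Model Name</Catalog>"
    "<SspropInitAppName>SemPy</SspropInitAppName>"
    "</PropertyList>"
)


def app_name_from_request_properties(xml_text: str) -> Optional[str]:
    """Return the text of the first child whose tag ends with SspropInitAppName."""
    root = ET.fromstring(xml_text)
    for child in root:
        # ElementTree prefixes namespaced tags as '{namespace}Tag', so compare the suffix.
        if child.tag.endswith("SspropInitAppName"):
            return child.text
    return None


app_name = app_name_from_request_properties(sample_request_properties)
print(app_name)
```

Matching on the tag suffix avoids hard-coding the namespace, which is the same approach `extract_session_id_from_query_begin` takes.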
### Pause & Resume Capacity

In [None]:
@log_function_calls
def update_model_pause_status(event: str, model: dict = None, workspace: str = None) -> None:
    """
    Updates the dictionary that tracks if a capacity pause is needed for each model.
    Different events trigger different updates:
      - "initialize": Set default based on storageMode.
      - "model_queried": Mark model as needing a pause for Import/DirectQuery.
      - "capacity_paused": Reset the pause flag after a capacity operation.
    """
    global model_pause_capacity_needed

    if event == "initialize":
        for m in models:
            model_pause_capacity_needed[m["name"]] = False if m["storageMode"] == "DirectLake" else True

    elif event == "model_queried" and model is not None:
        if model["storageMode"] == "Import":
            model_pause_capacity_needed[model["name"]] = True
        elif model["storageMode"] == "DirectQuery":
            target_db = model["database_name"]
            target_db_workspace = model["database_workspace_name"]
            for m in models:
                if (
                    m["storageMode"] == "DirectQuery" and 
                    m["database_name"] == target_db and 
                    m["database_workspace_name"] == target_db_workspace
                ):
                    model_pause_capacity_needed[m["name"]] = True

    elif event == "capacity_paused" and workspace is not None:
        # Reset pause flags after capacity operations complete
        for m in models:
            if m["model_workspace_name"] == workspace and m["storageMode"] == "Import":
                model_pause_capacity_needed[m["name"]] = False
        for m in models:
            if m["storageMode"] == "DirectQuery" and m["database_workspace_name"] == workspace:
                model_pause_capacity_needed[m["name"]] = False
    else:
        print(f"⚠️ Unknown event '{event}' or missing required parameter(s).")

    print("📝 Updated model_pause_capacity_needed")

@log_function_calls
@retry(Exception, tries=10, delay=5, backoff=2, logger=print)
def pause_resume_capacity(_capacity_name, _action, _simplify_logs=False):
    """
    Pauses or resumes a capacity based on the _action parameter.
    Uses Semantic Link Labs functions along with key vault authentication.
    """
    print(f"🔄 {_action.title()} capacity '{_capacity_name}': Attempting")

    # Check the current status of the capacity.
    current_status = get_capacity_status(_capacity_name)
    
    # Mapping for actions with expected and target statuses.
    action_options = {
        "pause": {
            "expected_status": "Active",  # Can only pause if capacity is active.
            "target_status": "Suspended",
            "action_function": labs.suspend_fabric_capacity,
        },
        "resume": {
            "expected_status": "Suspended",  # Can only resume if capacity is suspended.
            "target_status": "Active",
            "action_function": labs.resume_fabric_capacity,
        },
    }
    
    if current_status == action_options[_action]["expected_status"]:
        print(f"🛠️ {_action.title()} capacity '{_capacity_name}': Requesting action")
        
        # Authenticate using service principal details from the key vault.
        with labs.service_principal_authentication(
            key_vault_uri=key_vault_uri_secret_name,
            key_vault_tenant_id=key_vault_tenant_id_secret_name,
            key_vault_client_id=key_vault_client_id_secret_name,
            key_vault_client_secret=key_vault_client_secret_secret_name
        ):
            # Execute the pause or resume action.
            action_options[_action]["action_function"](
                capacity_name=_capacity_name,
                azure_subscription_id=subscription_id,
                resource_group=resource_group_name
            )
        
        # Wait for the capacity status to reach the target status.
        try:
            wait_for_capacity_status(_capacity_name, action_options[_action]["target_status"])
            print(f"✅ {_action.title()} capacity '{_capacity_name}': Action successful")
        except Exception as e:
            print(f"⚠️ {_action.title()} capacity '{_capacity_name}': Timeout waiting for target status. Error: {e}")
    else:
        print(f"ℹ️ {_action.title()} capacity '{_capacity_name}': Already '{current_status}'")

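The precondition logic in `pause_resume_capacity` can be read as a small transition table: an action is only requested when the capacity is in the expected starting state. The sketch below isolates that decision as a pure function so it can be reasoned about without calling any capacity API. The names `TRANSITIONS` and `target_status_if_actionable` are hypothetical; the status strings are the ones used in the function above.

```python
from typing import Optional

# action -> (status required before acting, status to wait for afterwards)
TRANSITIONS = {
    "pause": ("Active", "Suspended"),
    "resume": ("Suspended", "Active"),
}


def target_status_if_actionable(current_status: str, action: str) -> Optional[str]:
    """Return the status to wait for, or None if the action should be skipped."""
    expected_status, target_status = TRANSITIONS[action]
    return target_status if current_status == expected_status else None


print(target_status_if_actionable("Active", "pause"))      # Suspended
print(target_status_if_actionable("Suspended", "pause"))   # None
print(target_status_if_actionable("Suspended", "resume"))  # Active
```

A `None` result corresponds to the branch above that only reports the capacity's current status without requesting anything.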
### Cache-setting functions

In [None]:
@log_function_calls
@retry(exceptions=(Exception,), tries=5, delay=5, backoff=2)
def clear_vertipaq_cache(model: dict) -> None:
    """
    Clears the VertiPaq cache by calling labs.clear_cache.
    Verifies that the cache clear succeeded by evaluating a simple DAX query.
    """
    print("🧹 Clearing VertiPaq cache")
    wait_for_model_to_come_online(model)
    try:
        labs.clear_cache(model["name"], workspace=model["model_workspace_name"])
        dax_query_eval_1(model)
        print("✅ Clear VertiPaq cache successful")
    except Exception as e:
        print("🔄 Clearing VertiPaq cache failed; retrying...")
        fabric.refresh_tom_cache(model["model_workspace_name"])
        raise e
    time.sleep(5)

@log_function_calls
def set_hot_cache(model: dict, expression: str, successful_query_count_goal: int = 1) -> bool:
    """
    Primes the hot cache by executing the given query multiple times.
    Returns True if the required number of successful executions is reached.
    """
    print("🔥 Setting Hot Cache")
    successful_query_count = 0
    # Define how many attempts to make based on the goal
    number_of_query_attempts = successful_query_count_goal * 5 if successful_query_count_goal > 1 else 1

    if additional_arguments["skipSettingHotCache"]:
        # Skip actual execution if flag is set
        successful_query_count = successful_query_count_goal
    else:
        for _ in range(number_of_query_attempts):
            # Clear any existing traces before running the query
            fabric.create_trace_connection(model["name"], model["model_workspace_name"]).drop_traces()
            trace_name = f"Cache Trace {uuid4()}"
            with fabric.create_trace_connection(model["name"], model["model_workspace_name"]) as trace_conn:
                with trace_conn.create_trace(event_schema, trace_name) as trace:
                    print("🔍 Starting trace for hot cache")
                    trace.start()
                    wait_for_trace_start(trace_conn, trace_name)
                    try:
                        print("⚡ Executing DAX query for hot cache")
                        fabric.evaluate_dax(model["name"], expression, workspace=model["model_workspace_name"])
                        successful_query_count += 1
                        print("✅ DAX query succeeded for hot cache")
                    except Exception as e:
                        print("❌ DAX query failed for hot cache:", e)
                    print("📜 Collecting trace logs for hot cache")
                    collect_filtered_trace_logs(trace)
            # Exit early if enough successful queries have been executed
            if successful_query_count >= successful_query_count_goal:
                break

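    hot_cache_set = successful_query_count >= successful_query_count_goal
    if hot_cache_set:
        print(f"✅ Hot cache set; goal: {successful_query_count_goal} successful queries")
    else:
        print(f"⚠️ Hot cache not set; {successful_query_count} of {successful_query_count_goal} queries succeeded")
    return hot_cache_set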

@log_function_calls
def set_warm_cache(model: dict, expression: str) -> bool:
    """
    Sets a warm cache by first priming the hot cache then clearing the VertiPaq cache.
    For DirectQuery models, it calls set_cold_cache.
    """
    print("🔥 Setting Warm Cache")
    if model["storageMode"] == "DirectQuery":
        set_cold_cache(model)
    hot_cache_successful = set_hot_cache(model, expression)
    clear_vertipaq_cache(model)
    print("✅ Warm cache set")
    return hot_cache_successful

@retry(exceptions=(Exception,), tries=5, delay=5, backoff=2)
def _refresh_tom_cache(workspace_name: str) -> None:
    """
    Refreshes the TOM cache for the specified workspace.
    """
    print(f"⌛ Refreshing TOM cache for workspace '{workspace_name}'")
    fabric.refresh_tom_cache(workspace_name)
    print("✅ TOM cache refreshed successfully")

@retry(exceptions=(Exception,), tries=30, delay=3, backoff=2)
def _wait_for_refresh_to_complete(model: dict, refresh_id: str) -> None:
    """
    Waits for a dataset refresh to complete or fail by checking its status.
    """
    status = fabric.get_refresh_execution_details(
        model["name"],
        refresh_id,
        workspace=model["model_workspace_name"],
    ).status

    if status not in ["Completed", "Failed"]:
        raise Exception(f"Refresh status is '{status}'; not done yet.")
    print(f"✅ Refresh status: '{status}' - finishing polling.")

@log_function_calls
def set_cold_cache(model: dict) -> None:
    """
    Sets a cold cache by ensuring that any previous cache is cleared.
    For non-DirectLake models, it pauses/resumes capacities to simulate a cold state.
    For DirectLake, it performs a clear and full refresh.
    """
    print("❄️ Setting Cold Cache")
    if model["storageMode"] != "DirectLake":
        if model_pause_capacity_needed.get(model["name"], False):
            ws_caps = workspace_capacities[model["model_workspace_name"]]
            print(f"🔄 Assigning alternate capacity for workspace '{model['model_workspace_name']}'")
            labs.assign_workspace_to_capacity(ws_caps["alt_capacity_name"], model["model_workspace_name"])
            print(f"✅ Alternate capacity assigned: {ws_caps['alt_capacity_name']}")

            pause_resume_capacity(ws_caps["capacity_name"], "pause")
            pause_resume_capacity(ws_caps["capacity_name"], "resume")

            print(f"🔄 Reassigning primary capacity for workspace '{model['model_workspace_name']}'")
            labs.assign_workspace_to_capacity(ws_caps["capacity_name"], model["model_workspace_name"])
            print(f"✅ Primary capacity assigned: {ws_caps['capacity_name']}")

            _refresh_tom_cache(model["model_workspace_name"])
            wait_for_model_to_come_online(model)
            update_model_pause_status("capacity_paused", workspace=model["model_workspace_name"])
            time.sleep(30)
            clear_vertipaq_cache(model)
    else:
        print("ℹ️ Performing clear refresh for cold cache")
        refresh_status_clear = fabric.refresh_dataset(
            model["name"],
            refresh_type="clearValues",
            workspace=model["model_workspace_name"],
        )
        _wait_for_refresh_to_complete(model, refresh_status_clear)
        print("✅ Clear refresh completed; performing full refresh")
        refresh_status_full = fabric.refresh_dataset(
            model["name"],
            refresh_type="full",
            workspace=model["model_workspace_name"],
        )
        _wait_for_refresh_to_complete(model, refresh_status_full)
        clear_vertipaq_cache(model)
    print("✅ Cold cache set")

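The three cache-setting functions combine differently depending on storage mode. As a reading aid, the sketch below returns the preparation steps as plain labels for a given storage mode and cache type. It is a simplified summary under stated assumptions: the function name and labels are hypothetical, waits and sleeps are omitted, and the capacity steps for Import and DirectQuery only run when the model's pause flag is set.

```python
def cache_prep_steps(storage_mode: str, cache_type: str) -> list:
    """Summarize, as labels only, the preparation performed before the timed query."""
    if storage_mode == "DirectLake":
        cold = ["clearValues refresh", "full refresh", "clear VertiPaq cache"]
    else:
        # Import and DirectQuery: only performed when a capacity pause is flagged as needed.
        cold = [
            "assign alternate capacity",
            "pause primary capacity",
            "resume primary capacity",
            "reassign primary capacity",
            "refresh TOM cache",
            "clear VertiPaq cache",
        ]
    prime = ["run query to prime cache"]

    if cache_type == "cold":
        return cold
    if cache_type == "warm":
        # DirectQuery sets a cold cache first so the source database cache is reset too.
        prefix = cold if storage_mode == "DirectQuery" else []
        return prefix + prime + ["clear VertiPaq cache"]
    if cache_type == "hot":
        return prime
    raise ValueError(f"Unknown cache type: {cache_type}")


for mode in ("DirectLake", "Import", "DirectQuery"):
    print(mode, "warm:", cache_prep_steps(mode, "warm"))
```

Read this way, a warm run differs from a hot run only in that the VertiPaq cache is cleared after priming, and from a cold run in that the data has already been touched once.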
### Log-table helpers & query checks

In [None]:
@log_function_calls
def get_log_table(table_name: str):
    """
    Retrieves the log table as a Spark DataFrame and applies filters based on the current round and event class.
    """
    try:
        raw_table = spark.table(table_name)
        base_filters = (
            (col("roundNumber") == additional_arguments["roundNumber"]) &
            (col("eventClass") == "QueryEnd")
        )
        return raw_table.filter(base_filters)
    except Exception:
        return None

@log_function_calls
def max_queries_met(check_logs: bool, log_table, model_cache_combo: dict, queryId: int) -> bool:
    """
    Determines if the maximum number of queries (successful or failed) has been met for the given model, cache type, and queryId.
    """
    if check_logs and log_table is not None:
        base_filters = (
            (col("modelName") == model_cache_combo["model"]["name"]) &
            (col("queryId") == queryId) &
            (col("cacheType") == model_cache_combo["cache_type"])
        )
        
        success_count = log_table.filter(base_filters & (col("success") == "Success")).count()
        failure_count = log_table.filter(base_filters & (col("success") == "Failure")).count()

        result = (success_count >= additional_arguments["maxNumberPerQuery"] or 
                  failure_count >= additional_arguments["maxFailuresBeforeSkipping"])
        print(f"📊 {'Skipping' if result else 'Continuing'} queries (Success: {success_count}, Failure: {failure_count})")
        return result
    else:
        return False

@log_function_calls
def get_starting_query_id(log_table, additional_args: dict) -> int:
    """
    Determines the next queryId to start from based on the existing logs.
    If no logs are present, starts at 1.
    """
    print("🔍 Determining starting query ID")
    if log_table is not None:
        success_failure_counts = log_table.groupBy("modelName", "cacheType", "queryId").agg(
            _sum(when(col("success") == "Success", 1).otherwise(0)).alias("success_count"),
            _sum(when(col("success") == "Failure", 1).otherwise(0)).alias("failure_count")
        )

        valid_queries = success_failure_counts.filter(
            (col("success_count") >= additional_args["maxNumberPerQuery"]) |
            (col("failure_count") >= additional_args["maxFailuresBeforeSkipping"])
        )

        distinct_combos_count = valid_queries.select("modelName", "cacheType").distinct().count()

        valid_query_ids = valid_queries.groupBy("queryId").agg(
            countDistinct("modelName", "cacheType").alias("valid_combinations")
        ).filter(col("valid_combinations") == distinct_combos_count)

        query_id_list = sorted([row["queryId"] for row in valid_query_ids.select("queryId").collect()])
        max_query_id = 0
        for i, qid in enumerate(query_id_list):
            if qid != i + 1:
                break
            max_query_id = qid

        starting_query_id = 1 if max_query_id == 0 else max_query_id + 1
    else:
        print(f"ℹ️ Log table {additional_args['logTableName']} does not exist")
        starting_query_id = 1

    print(f"✅ Starting query ID set to {starting_query_id}")
    return starting_query_id

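The final step of `get_starting_query_id` looks for the first gap in the completed query IDs, counting up from 1. The sketch below reproduces just that step in plain Python, assuming the list of fully completed IDs has already been collected from the log table. The function name `next_starting_query_id` is hypothetical.

```python
def next_starting_query_id(completed_query_ids: list) -> int:
    """Return the first queryId not covered by an unbroken run 1, 2, 3, ..."""
    last_contiguous = 0
    for position, query_id in enumerate(sorted(completed_query_ids), start=1):
        if query_id != position:
            break  # Gap found: everything from here on still needs testing.
        last_contiguous = query_id
    return last_contiguous + 1


print(next_starting_query_id([1, 2, 3, 5]))  # 4
print(next_starting_query_id([]))            # 1
print(next_starting_query_id([2, 3]))        # 1
```

A completed ID after a gap (such as 5 in the first call) does not move the starting point forward; the per-query check in `max_queries_met` is what skips it later.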
### Main DAX testing orchestration functions

In [None]:
@log_function_calls
def run_dax_query_and_collect_logs(model_cache_combo: dict, dax_query: pd.Series, log_table) -> str:
    """
    Runs a single DAX query based on the provided model, cache type, and queryId.
    Handles cache setup, execution, trace logging, and appends the results to the log table.
    Returns "Ran" if the query executed or "Skipped" if conditions indicate to skip it.
    """
    model = model_cache_combo["model"]
    sent_dax_expression = dax_query[model["runQueryType"]]
    query_run_name = f"Model: {model['name']}, QueryId: {dax_query['queryId']}, Cache Type: {model_cache_combo['cache_type']}"
    valid_cache_type_for_model = model_cache_combo["cache_type"] in model["cache_types"]
    
    print(f"🚀 Starting query: {query_run_name}")
    
    # Check if this query should be run based on log history and cache type validity.
    if (not max_queries_met(additional_arguments["onlyRunNewQueries"], log_table, model_cache_combo, dax_query["queryId"])
        and valid_cache_type_for_model):
        set_cache_start_time = datetime.now().isoformat()
        wait_for_model_to_come_online(model)
        
        # Set the cache according to the type specified.
        if model_cache_combo["cache_type"] == "cold":
            set_cold_cache(model)
            cache_set = True
        elif model_cache_combo["cache_type"] == "warm":
            cache_set = set_warm_cache(model, sent_dax_expression)
        else:  # hot cache
            cache_set = set_hot_cache(model, sent_dax_expression)
        
        time.sleep(additional_arguments["pauseAfterSettingCache"])
        set_cache_end_time = datetime.now().isoformat()
        update_model_pause_status("model_queried", model=model)
        query_start_time = datetime.now().isoformat()
        
        # Start a trace for the DAX query execution.
        fabric.create_trace_connection(model["name"], model["model_workspace_name"]).drop_traces()
        trace_name = f"Simple DAX Trace {uuid4()}"
        
        with fabric.create_trace_connection(model["name"], model["model_workspace_name"]) as trace_conn:
            with trace_conn.create_trace(event_schema, trace_name) as trace:
                print("🔍 Starting trace for DAX query")
                trace.start()
                wait_for_trace_start(trace_conn, trace_name)
                
                dax_query_result = "Success"
                try:
                    print("⚡ Executing DAX query")
                    fabric.evaluate_dax(model["name"], sent_dax_expression, workspace=model["model_workspace_name"])
                    print("✅ DAX query executed successfully")
                except Exception as e:
                    dax_query_result = str(e)
                    print("❌ DAX query execution failed:", e)
                
                print("📜 Collecting trace logs")
                filtered_logs = collect_filtered_trace_logs(trace)
                
                time.sleep(additional_arguments["pauseAfterRunningQuery"])
                query_end_time = datetime.now().isoformat()

                # Capture request properties from the QueryBegin event.
                query_begin_rows = filtered_logs[filtered_logs["Event Class"] == "QueryBegin"]
                request_properties_value = query_begin_rows.iloc[0]["Request Properties"] if not query_begin_rows.empty else None
                filtered_logs["Request Properties"] = request_properties_value

                # Convert all column names to camelCase.
                filtered_logs.columns = [to_camel_case(column_name) for column_name in filtered_logs.columns]
                
                # Append metadata columns to the trace logs.
                filtered_logs = filtered_logs.assign(
                    sentExpression=sent_dax_expression,
                    setCacheStartTime=set_cache_start_time,
                    setCacheEndTime=set_cache_end_time,
                    queryStartTime=query_start_time,
                    queryEndTime=query_end_time,
                    modelName=model["name"],
                    queryId=dax_query["queryId"],
                    runQueryType=model["runQueryType"],
                    cacheType=model_cache_combo["cache_type"],
                    queryResult=dax_query_result,
                    storageMode=model["storageMode"],
                    roundNumber=additional_arguments["roundNumber"],
                )
                
                if not cache_set:
                    # Adds a separate `Success` column (distinct from `queryResult`) flagging that cache setup failed.
                    print("❌ Cache was not set properly; marking query as failed")
                    filtered_logs = filtered_logs.assign(Success="Failure")
        
        # Convert the pandas DataFrame to a Spark DataFrame and append to the log table.
        filtered_logs_df = spark.createDataFrame(filtered_logs)
        
        print("💾 Appending logs to table")
        filtered_logs_df.write.format("delta").mode("append").option("mergeSchema", "true").saveAsTable(additional_arguments["logTableName"])
        print("✅ Logs appended to table")
        print("ℹ️ Pausing between runs")
        time.sleep(additional_arguments["pauseBetweenRuns"])
        
        return "Ran"
    else:
        print("⏭️ Query skipped (logs exist or invalid cache type)")
        return "Skipped"

@log_function_calls
def run_dax_queries() -> None:
    """
    Main entry point for running all DAX queries from the Excel or YAML query file.
    Manages the log table, capacity checks, and iterates over all queries and their combinations.
    """
    print("🚀 Starting all DAX queries")

    # Load the DAX queries using the configuration parameters.
    dax_queries = load_dax_queries(query_file_path, query_file_mount_path, query_worksheet_name)
    
    # Initialize pause status flags for all models.
    update_model_pause_status("initialize")

    # Manage log table cleanup based on additional arguments.
    if additional_arguments["clearCurrentRoundLogs"]:
        print(f"🗑️ Dropping round {additional_arguments['roundNumber']} logs from {additional_arguments['logTableName']}")
        spark.sql(f"DELETE FROM {additional_arguments['logTableName']} WHERE roundNumber = {additional_arguments['roundNumber']}")
    if additional_arguments["clearAllLogs"]:
        print(f"🗑️ Dropping entire table {additional_arguments['logTableName']}")
        spark.sql(f"DROP TABLE IF EXISTS {additional_arguments['logTableName']}")
        startQueryIdsAt = 1
        log_table = None
    else:
        print(f"🔍 Retrieving table {additional_arguments['logTableName']}")
        log_table = get_log_table(additional_arguments["logTableName"])
        startQueryIdsAt = (1 if additional_arguments["clearCurrentRoundLogs"] or additional_arguments["forceStartQueriesAt1"] or log_table is None
                           else get_starting_query_id(log_table, additional_arguments))

    # Check if capacity pause/resume logic is required (for Import/DirectQuery with cold cache)
    include_pause_resume_logic = any(
        model["storageMode"] in ["Import", "DirectQuery"] and "cold" in model["cache_types"]
        for model in models
    )

    if include_pause_resume_logic:
        # Validate that all necessary workspaces are defined in the workspace_capacities
        for m in models:
            ws = m["model_workspace_name"]
            if ws not in workspace_capacities:
                raise ValueError(f"The workspace '{ws}' (in model '{m['name']}') is not defined in workspace_capacities.")
        # Resume capacities before starting queries
        for ws, caps in workspace_capacities.items():
            pause_resume_capacity(caps["capacity_name"], "resume", _simplify_logs=True)
            pause_resume_capacity(caps["alt_capacity_name"], "resume", _simplify_logs=True)
            print(f"✅ Assigning primary capacity '{caps['capacity_name']}' to workspace '{ws}'")
            labs.assign_workspace_to_capacity(caps["capacity_name"], ws)

    # Loop through each DAX query loaded from the Excel or YAML query file,
    # keeping only queryIds within [startQueryIdsAt, stopQueryIdsAt].
    for _, dax_query in dax_queries.iterrows():
        if startQueryIdsAt <= dax_query["queryId"] <= additional_arguments["stopQueryIdsAt"]:
            for _ in range(additional_arguments["numberOfRunsPerQueryId"]):
                total_query_count = 0
                skipped_query_count = 0
                # Build every (model, cache_type) combination once.
                combos = pd.DataFrame(
                    itertools.product(models, ["cold", "warm", "hot"]),
                    columns=["model", "cache_type"],
                )
                # Randomize the run order if specified; otherwise sort by model name, then cold, warm, hot.
                if additional_arguments["randomizeRuns"]:
                    print("🔀 Randomizing run order of (model, cache_type)")
                    model_cache_combo = combos.sample(frac=1).reset_index(drop=True)
                else:
                    combos["model_name"] = combos["model"].apply(lambda m: m["name"])
                    combos["cache_order"] = pd.Categorical(combos["cache_type"], categories=["cold", "warm", "hot"], ordered=True)
                    model_cache_combo = (
                        combos.sort_values(by=["model_name", "cache_order"])
                        .drop(["model_name", "cache_order"], axis=1)
                        .reset_index(drop=True)
                    )

                # Execute the query for each model and cache type combination.
                for _, current_combo in model_cache_combo.iterrows():
                    total_query_count += 1
                    if include_pause_resume_logic:
                        update_model_pause_status("model_queried", model=current_combo["model"])
                    run_status = run_dax_query_and_collect_logs(current_combo, dax_query, log_table)
                    if run_status == "Skipped":
                        skipped_query_count += 1

                print(f"🔄 Refreshing log table from {additional_arguments['logTableName']}")
                log_table = get_log_table(additional_arguments["logTableName"])

                if total_query_count == skipped_query_count:
                    print("ℹ️ No new queries; skipping additional runs for this query group")
                    break

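    # Once all queries have finished, pause the alternate capacities that were resumed for cold-cache runs.
    if include_pause_resume_logic:
        for ws, caps in workspace_capacities.items():
            pause_resume_capacity(caps["alt_capacity_name"], "pause", _simplify_logs=True)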
    print("✅ All queries complete")

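### Example: how the run order is built

For each query, `run_dax_queries` pairs every model with every cache type (`cold`, `warm`, `hot`). It then either shuffles the pairs or sorts them by model name and cache order. The following is a minimal sketch of that ordering only, not the notebook's implementation: it swaps the pandas calls for the standard library (`random.shuffle` and `sorted`), and `build_run_order`, `seed`, and the sample model dictionaries are hypothetical names introduced for illustration.

```python
import itertools
import random

CACHE_ORDER = ["cold", "warm", "hot"]


def build_run_order(models, randomize, seed=None):
    """Return (model, cache_type) pairs in the order they would be executed."""
    # Every model is paired with every cache type.
    combos = list(itertools.product(models, CACHE_ORDER))
    if randomize:
        # Shuffle in place; `seed` is an illustrative extra for repeatable shuffles.
        random.Random(seed).shuffle(combos)
        return combos
    # Otherwise sort by model name, then cold, warm, hot.
    return sorted(combos, key=lambda pair: (pair[0]["name"], CACHE_ORDER.index(pair[1])))


# Hypothetical models; the notebook defines the real `models` list in its configuration cell.
sample_models = [{"name": "Sales Import"}, {"name": "Sales DirectQuery"}]

for model, cache_type in build_run_order(sample_models, randomize=False):
    print(model["name"], cache_type)
```

In the notebook itself, pairs whose cache type is not listed in the model's `cache_types` are still generated here and are skipped later inside `run_dax_query_and_collect_logs`.
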
### Execute main flow

In [None]:
run_dax_queries()

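### Example: checking `workspace_capacities` before a run

`run_dax_queries` raises a `ValueError` when pause/resume logic is needed (an Import or DirectQuery model with `cold` in its `cache_types`) and a model's workspace is missing from `workspace_capacities`. A pre-flight check along the following lines may surface that problem before any capacity is resumed. This is a sketch under assumptions: the helper names `needs_pause_resume` and `missing_workspaces`, and all sample values (model names, workspace names, capacity names, and the `DirectLake` storage mode string), are hypothetical.

```python
def needs_pause_resume(models):
    """True if any Import/DirectQuery model is configured for cold-cache runs."""
    return any(
        m["storageMode"] in ("Import", "DirectQuery") and "cold" in m["cache_types"]
        for m in models
    )


def missing_workspaces(models, workspace_capacities):
    """Return the sorted workspace names used by models but absent from workspace_capacities."""
    used = {m["model_workspace_name"] for m in models}
    return sorted(used - set(workspace_capacities))


# Hypothetical configuration, shaped like the dictionaries the notebook reads.
sample_models = [
    {"name": "Sales Import", "storageMode": "Import",
     "cache_types": ["cold", "hot"], "model_workspace_name": "Perf WS"},
    {"name": "Sales DL", "storageMode": "DirectLake",
     "cache_types": ["warm"], "model_workspace_name": "Other WS"},
]
sample_capacities = {
    "Perf WS": {"capacity_name": "cap-a", "alt_capacity_name": "cap-b"},
}

if needs_pause_resume(sample_models):
    print("Missing from workspace_capacities:",
          missing_workspaces(sample_models, sample_capacities))
```

As in the main flow, the check covers every model's workspace once any model needs pause/resume, not only the workspaces of the cold-cache models.
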
### Stop session

In [None]:
mssparkutils.session.stop()